use std::collections::BTreeMap;
use std::fs::File;
use std::io::Read;
use std::path::Path;
use serde_json::Value;
use thiserror::Error;
/// Tensor-key layout detected in an LTX-2 single-file checkpoint.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LtxKeyFormat {
    /// Transformer tensors are keyed directly as `transformer_blocks.*`.
    Native,
    /// Transformer tensors are keyed as `model.diffusion_model.transformer_blocks.*`.
    Diffusers,
}
/// Summary of an LTX-2 single-file checkpoint, derived from its safetensors header.
#[derive(Debug, Clone)]
pub struct Ltx2SingleFileBundle {
    /// Transformer key layout the checkpoint uses.
    // NOTE(review): the `dead_code` allows below presumably exist because nothing in
    // the crate reads these fields yet — confirm, and drop them once a reader exists.
    #[allow(dead_code)]
    pub format: LtxKeyFormat,
    /// Number of transformer tensors in the selected layout, not counting
    /// quantization side-tensors (`.weight_scale`, `.input_scale`, ...).
    #[allow(dead_code)]
    pub transformer_key_count: usize,
    /// `true` when at least one tensor key sits under the `vae` prefix.
    pub has_vae: bool,
    /// `true` when NVFP4 markers (`.weight_scale_2` or `.comfy_quant` keys) were seen.
    #[allow(dead_code)]
    pub has_nvfp4: bool,
    /// The `model_version` string from the header's `__metadata__`, if present.
    pub model_version: Option<String>,
}
/// Failures that can occur while probing an LTX-2 single-file checkpoint.
#[derive(Debug, Error)]
pub enum LoadError {
    /// The checkpoint file could not be opened or read.
    #[error("io: {0}")]
    Io(#[from] std::io::Error),
    /// The safetensors header was rejected (e.g. malformed JSON); carries a description.
    #[error("safetensors header parse failed: {0}")]
    Header(String),
    /// Neither supported transformer key layout appears in the header.
    #[error(
        "no LTX-2 transformer keys found \
         (expected `transformer_blocks.*` or `model.diffusion_model.transformer_blocks.*`)"
    )]
    NoTransformerKeys,
}
pub fn load(path: &Path) -> Result<Ltx2SingleFileBundle, LoadError> {
let header = read_header(path)?;
let mut native_count = 0usize;
let mut diffusers_count = 0usize;
let mut vae_count = 0usize;
let mut nvfp4 = false;
for key in header.keys() {
if key == "__metadata__" {
continue;
}
if key.ends_with(".weight_scale_2") || key.ends_with(".comfy_quant") {
nvfp4 = true;
continue;
}
if key.ends_with(".weight_scale") || key.ends_with(".input_scale") {
continue;
}
if has_prefix(key, NATIVE_TRANSFORMER_KEY) {
native_count += 1;
} else if has_prefix(key, DIFFUSERS_TRANSFORMER_KEY) {
diffusers_count += 1;
} else if has_prefix(key, VAE_KEY) {
vae_count += 1;
}
}
let (format, transformer_key_count) = if diffusers_count > 0 {
(LtxKeyFormat::Diffusers, diffusers_count)
} else if native_count > 0 {
(LtxKeyFormat::Native, native_count)
} else {
return Err(LoadError::NoTransformerKeys);
};
let model_version = header
.get("__metadata__")
.and_then(|v| v.get("model_version"))
.and_then(|v| v.as_str())
.map(str::to_owned);
Ok(Ltx2SingleFileBundle {
format,
transformer_key_count,
has_vae: vae_count > 0,
has_nvfp4: nvfp4,
model_version,
})
}
/// Key prefix of transformer tensors in the `LtxKeyFormat::Native` layout.
const NATIVE_TRANSFORMER_KEY: &str = "transformer_blocks";
/// Key prefix of transformer tensors in the `LtxKeyFormat::Diffusers` layout.
const DIFFUSERS_TRANSFORMER_KEY: &str = "model.diffusion_model.transformer_blocks";
/// Key prefix of VAE tensors bundled into the same checkpoint.
const VAE_KEY: &str = "vae";
/// Returns `true` when `key` is exactly `prefix`, or when `prefix` is followed
/// by a `.` separator in `key`.
///
/// So `vae` matches `vae.encoder.w` but not `vaex.w`.
fn has_prefix(key: &str, prefix: &str) -> bool {
    match key.strip_prefix(prefix) {
        Some(rest) => rest.is_empty() || rest.starts_with('.'),
        None => false,
    }
}
/// Reads and parses the JSON header of a safetensors file.
///
/// The file begins with a little-endian `u64` giving the header's byte
/// length, followed by that many bytes of JSON. The declared length comes
/// from untrusted input, so it is validated before any buffer is allocated.
/// Otherwise a corrupt or hostile file could request an enormous allocation
/// and panic or abort the process, or silently truncate on 32-bit targets.
///
/// # Errors
///
/// - [`LoadError::Io`] if the file cannot be opened or stat'ed, or is shorter
///   than the 8-byte length prefix.
/// - [`LoadError::Header`] if the declared length exceeds the safetensors
///   limit or the file size, or if the header is not valid JSON.
fn read_header(path: &Path) -> Result<BTreeMap<String, Value>, LoadError> {
    // Same upper bound the reference safetensors implementation enforces.
    const MAX_HEADER_LEN: u64 = 100_000_000;
    const LEN_PREFIX: u64 = 8;

    let mut file = File::open(path)?;
    let file_len = file.metadata()?.len();

    let mut len_buf = [0u8; 8];
    file.read_exact(&mut len_buf)?;
    let header_len = u64::from_le_bytes(len_buf);

    if header_len > MAX_HEADER_LEN {
        return Err(LoadError::Header(format!(
            "declared header length {} exceeds the {}-byte limit",
            header_len, MAX_HEADER_LEN
        )));
    }
    if header_len > file_len.saturating_sub(LEN_PREFIX) {
        return Err(LoadError::Header(format!(
            "declared header length {} exceeds file size {}",
            header_len, file_len
        )));
    }

    let header_len = usize::try_from(header_len).map_err(|_| {
        LoadError::Header(format!(
            "header length {} does not fit in memory on this platform",
            header_len
        ))
    })?;
    let mut header_buf = vec![0u8; header_len];
    file.read_exact(&mut header_buf)?;
    serde_json::from_slice(&header_buf).map_err(|e| LoadError::Header(e.to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
    use std::collections::HashMap;
    use std::path::PathBuf;

    /// Fixture path under the system temp dir that is deleted on drop.
    /// A failing assertion therefore does not leave stray files behind.
    struct TempFile(PathBuf);

    impl TempFile {
        /// Builds a unique path; `tag` keeps fixtures of different tests apart.
        fn new(tag: &str) -> Self {
            let mut p = std::env::temp_dir();
            p.push(format!(
                "mold-ltx2-sf-{}-{}-{}.safetensors",
                tag,
                std::process::id(),
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_nanos(),
            ));
            Self(p)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file may never have been written if the test failed early.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Writes a safetensors file with one single-element f32 tensor per key.
    fn write_fixture(path: &Path, keys: &[&str]) {
        write_fixture_with_metadata(path, keys, None);
    }

    /// Like [`write_fixture`], additionally storing `metadata` as `__metadata__`.
    fn write_fixture_with_metadata(
        path: &Path,
        keys: &[&str],
        metadata: Option<HashMap<String, String>>,
    ) {
        let zero = 0.0f32.to_le_bytes();
        let tensors: HashMap<String, TensorView<'_>> = keys
            .iter()
            .map(|key| {
                let view = TensorView::new(SafeDtype::F32, vec![1], &zero).unwrap();
                (key.to_string(), view)
            })
            .collect();
        serialize_to_file(&tensors, &metadata, path).unwrap();
    }

    #[test]
    fn native_combined_checkpoint() {
        let f = TempFile::new("native-combined");
        write_fixture(
            f.path(),
            &[
                "transformer_blocks.0.attn1.to_q.weight",
                "transformer_blocks.0.attn1.to_k.weight",
                "proj_in.weight",
                "vae.encoder.conv_in.weight",
                "vae.decoder.conv_out.weight",
            ],
        );
        let bundle = load(f.path()).expect("native combined load");
        assert_eq!(bundle.format, LtxKeyFormat::Native);
        assert_eq!(bundle.transformer_key_count, 2);
        assert!(bundle.has_vae);
    }

    #[test]
    fn native_transformer_only_rejected_by_vae_check() {
        let f = TempFile::new("native-no-vae");
        write_fixture(
            f.path(),
            &["transformer_blocks.0.attn1.to_q.weight", "proj_in.weight"],
        );
        let bundle = load(f.path()).expect("native transformer-only parses");
        assert_eq!(bundle.format, LtxKeyFormat::Native);
        assert!(
            !bundle.has_vae,
            "transformer-only checkpoint must report no VAE"
        );
    }

    #[test]
    fn diffusers_combined_checkpoint() {
        let f = TempFile::new("diffusers-combined");
        write_fixture(
            f.path(),
            &[
                "model.diffusion_model.transformer_blocks.0.attn1.to_q.weight",
                "model.diffusion_model.patchify_proj.weight",
                "vae.encoder.conv_in.weight",
            ],
        );
        let bundle = load(f.path()).expect("diffusers combined load");
        assert_eq!(bundle.format, LtxKeyFormat::Diffusers);
        assert!(bundle.has_vae);
    }

    #[test]
    fn nvfp4_checkpoint_parses_as_supported_transformer() {
        let f = TempFile::new("nvfp4");
        write_fixture(
            f.path(),
            &[
                "model.diffusion_model.transformer_blocks.0.attn1.to_q.weight",
                "model.diffusion_model.transformer_blocks.0.attn1.to_q.weight_scale",
                "model.diffusion_model.transformer_blocks.0.attn1.to_q.weight_scale_2",
                "vae.encoder.conv_in.weight",
            ],
        );
        let bundle = load(f.path()).expect("NVFP4 LTX-2 checkpoint should parse");
        assert_eq!(bundle.format, LtxKeyFormat::Diffusers);
        assert_eq!(bundle.transformer_key_count, 1);
        assert!(bundle.has_nvfp4);
        assert!(bundle.has_vae);
    }

    #[test]
    fn no_transformer_keys_returns_error() {
        let f = TempFile::new("no-transformer");
        write_fixture(f.path(), &["vae.encoder.conv_in.weight"]);
        assert!(matches!(load(f.path()), Err(LoadError::NoTransformerKeys)));
    }

    #[test]
    fn prefix_match_requires_dot_boundary() {
        // Keys that merely share leading characters with a known prefix must not count.
        let f = TempFile::new("prefix-boundary");
        write_fixture(
            f.path(),
            &["transformer_blocks_extra.weight", "vaex.encoder.weight"],
        );
        assert!(matches!(load(f.path()), Err(LoadError::NoTransformerKeys)));
    }

    #[test]
    fn model_version_read_from_metadata() {
        let f = TempFile::new("metadata");
        let mut meta = HashMap::new();
        meta.insert("model_version".to_string(), "2.0".to_string());
        write_fixture_with_metadata(
            f.path(),
            &["transformer_blocks.0.attn1.to_q.weight"],
            Some(meta),
        );
        let bundle = load(f.path()).expect("checkpoint with metadata parses");
        assert_eq!(bundle.model_version.as_deref(), Some("2.0"));
    }
}