use std::collections::BTreeMap;
use std::fs::File;
use std::io::Read;
use std::path::Path;
use serde_json::Value;
use thiserror::Error;
/// Naming scheme used by the transformer tensor keys of a checkpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LtxKeyFormat {
/// Reported by `load` when keys sit under `transformer_blocks` and none sit
/// under the `model.diffusion_model.transformer_blocks` prefix.
Native,
/// Reported by `load` when at least one key sits under
/// `model.diffusion_model.transformer_blocks`.
Diffusers,
}
/// Summary produced by `load` after inspecting the tensor keys of a single
/// safetensors file.
#[derive(Debug, Clone)]
pub struct LtxVideoSingleFileBundle {
/// Key layout detected for the transformer weights.
pub format: LtxKeyFormat,
/// Number of tensor keys that matched the transformer prefix belonging to
/// `format`.
// NOTE(review): the reason for silencing `dead_code` is not visible in this
// file — presumably nothing outside the tests reads this count yet; confirm.
#[allow(dead_code)]
pub transformer_key_count: usize,
/// `true` when at least one tensor key sits under the `vae` prefix.
pub has_vae: bool,
}
/// Failures that `load` can report.
#[derive(Debug, Error)]
pub enum LoadError {
/// An I/O operation on the checkpoint file failed; converted automatically
/// from `std::io::Error` through `#[from]`.
#[error("io: {0}")]
Io(#[from] std::io::Error),
/// The safetensors header was rejected; the string describes the failure
/// (for example the JSON parser's message).
#[error("safetensors header parse failed: {0}")]
Header(String),
/// The header was readable but none of its keys sit under either of the
/// recognised transformer-block prefixes.
#[error(
"no LTX-Video transformer keys found \
(expected `transformer_blocks.*` or `model.diffusion_model.transformer_blocks.*`)"
)]
NoTransformerKeys,
}
/// Inspects the safetensors file at `path` and reports which LTX-Video key
/// layout it uses and whether it also carries VAE weights.
///
/// Only the tensor names from the file header are examined. Each name is
/// attributed to at most one category (native transformer, Diffusers
/// transformer, VAE); names matching none of the prefixes are ignored.
/// When both transformer layouts are present, Diffusers takes precedence and
/// `transformer_key_count` reflects the Diffusers-prefixed keys only.
///
/// # Errors
///
/// * Whatever `read_tensor_keys` returns when the header cannot be read.
/// * [`LoadError::NoTransformerKeys`] when no name matches either
///   transformer prefix.
pub fn load(path: &Path) -> Result<LtxVideoSingleFileBundle, LoadError> {
    let tensor_names = read_tensor_keys(path)?;

    // Single pass over the names, tallying (native, diffusers, vae) hits.
    let (native_hits, diffusers_hits, vae_hits) = tensor_names.iter().fold(
        (0usize, 0usize, 0usize),
        |(native, diffusers, vae), name| {
            if has_prefix(name, NATIVE_TRANSFORMER_KEY) {
                (native + 1, diffusers, vae)
            } else if has_prefix(name, DIFFUSERS_TRANSFORMER_KEY) {
                (native, diffusers + 1, vae)
            } else if has_prefix(name, VAE_KEY) {
                (native, diffusers, vae + 1)
            } else {
                (native, diffusers, vae)
            }
        },
    );

    // Diffusers wins whenever it is present; native is the fallback.
    let (format, transformer_key_count) = match (diffusers_hits, native_hits) {
        (0, 0) => return Err(LoadError::NoTransformerKeys),
        (0, native) => (LtxKeyFormat::Native, native),
        (diffusers, _) => (LtxKeyFormat::Diffusers, diffusers),
    };

    Ok(LtxVideoSingleFileBundle {
        format,
        transformer_key_count,
        has_vae: vae_hits > 0,
    })
}
/// Key prefix whose matches `load` counts toward `LtxKeyFormat::Native`.
const NATIVE_TRANSFORMER_KEY: &str = "transformer_blocks";
/// Key prefix whose matches `load` counts toward `LtxKeyFormat::Diffusers`.
const DIFFUSERS_TRANSFORMER_KEY: &str = "model.diffusion_model.transformer_blocks";
/// Key prefix whose presence makes `load` report `has_vae`.
const VAE_KEY: &str = "vae";
/// Reports whether `key` lies under `prefix` when both are read as
/// dot-separated paths.
///
/// `key` matches when it equals `prefix` exactly or continues with a `.`
/// immediately after it. A plain textual prefix is not enough: `"vae_extra"`
/// is not under `"vae"`.
fn has_prefix(key: &str, prefix: &str) -> bool {
    match key.strip_prefix(prefix) {
        // Whatever follows the prefix must be nothing or start a new segment.
        Some(rest) => rest.is_empty() || rest.starts_with('.'),
        None => false,
    }
}
/// Reads the tensor names listed in the header of the safetensors file at
/// `path`.
///
/// The file is expected to begin with an 8-byte little-endian header length
/// followed by that many bytes of JSON whose top-level keys are tensor names.
/// The `__metadata__` entry is dropped and the remaining names are returned
/// in sorted order. Nothing past the header is read.
///
/// # Errors
///
/// * [`LoadError::Io`] if the file cannot be opened, or ends before the
///   length prefix or the declared header has been fully read.
/// * [`LoadError::Header`] if the declared header length exceeds
///   `MAX_HEADER_LEN`, or the header bytes are not a JSON object.
fn read_tensor_keys(path: &Path) -> Result<Vec<String>, LoadError> {
    // Ceiling on the declared header size. Without it, a corrupt, truncated
    // or non-safetensors file can announce an absurd length and the buffer
    // allocation below panics or aborts instead of yielding an error.
    // NOTE(review): 100 MB is believed to match the limit enforced by the
    // reference safetensors parser — confirm against the upstream crate.
    const MAX_HEADER_LEN: u64 = 100_000_000;

    let mut file = File::open(path)?;

    let mut len_buf = [0u8; 8];
    file.read_exact(&mut len_buf)?;
    let declared_len = u64::from_le_bytes(len_buf);
    if declared_len > MAX_HEADER_LEN {
        return Err(LoadError::Header(format!(
            "declared header length {} exceeds the {}-byte limit",
            declared_len, MAX_HEADER_LEN
        )));
    }
    // Lossless: the value was just bounded by MAX_HEADER_LEN, which fits in
    // `usize` on both 32-bit and 64-bit targets.
    let header_len = declared_len as usize;

    let mut header_buf = vec![0u8; header_len];
    file.read_exact(&mut header_buf)?;

    let header: BTreeMap<String, Value> =
        serde_json::from_slice(&header_buf).map_err(|e| LoadError::Header(e.to_string()))?;
    Ok(header.into_keys().filter(|k| k != "__metadata__").collect())
}
#[cfg(test)]
mod tests {
    use super::*;
    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
    use std::collections::HashMap;
    use std::path::PathBuf;

    /// Fixture file on disk that is deleted when the guard goes out of scope.
    ///
    /// Cleanup lives in `Drop` so the file is removed even when an assertion
    /// in the test panics, instead of only on the success path.
    struct TempFixture {
        path: PathBuf,
    }

    impl TempFixture {
        /// Writes a safetensors file containing one tiny tensor per key.
        fn new(tag: &str, keys: &[&str]) -> Self {
            // Build the guard first so a failed write still gets cleaned up.
            let fixture = TempFixture {
                path: temp_path(tag),
            };
            write_fixture(&fixture.path, keys);
            fixture
        }
    }

    impl Drop for TempFixture {
        fn drop(&mut self) {
            // Best effort: a missing file is not worth failing a test over.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    /// Builds a temp-dir path made unique by the tag, process id and the
    /// current time in nanoseconds.
    fn temp_path(tag: &str) -> PathBuf {
        let mut p = std::env::temp_dir();
        p.push(format!(
            "mold-ltxvideo-sf-{}-{}-{}.safetensors",
            tag,
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos(),
        ));
        p
    }

    /// Serializes one single-element F32 tensor (value zero) under each key.
    fn write_fixture(path: &Path, keys: &[&str]) {
        let zero = 0.0f32.to_le_bytes().to_vec();
        let bufs: Vec<Vec<u8>> = keys.iter().map(|_| zero.clone()).collect();
        let mut tensors: HashMap<String, TensorView<'_>> = HashMap::new();
        for (key, buf) in keys.iter().zip(bufs.iter()) {
            tensors.insert(
                key.to_string(),
                TensorView::new(SafeDtype::F32, vec![1], buf).unwrap(),
            );
        }
        serialize_to_file(&tensors, &None, path).unwrap();
    }

    #[test]
    fn native_transformer_only_checkpoint() {
        let fixture = TempFixture::new(
            "native-no-vae",
            &[
                "transformer_blocks.0.attn1.to_q.weight",
                "transformer_blocks.0.attn1.to_k.weight",
                "proj_in.weight",
                "caption_projection.linear_1.weight",
            ],
        );
        let bundle = load(&fixture.path).expect("native transformer-only load");
        assert_eq!(bundle.format, LtxKeyFormat::Native);
        assert_eq!(bundle.transformer_key_count, 2);
        assert!(!bundle.has_vae);
    }

    #[test]
    fn native_combined_transformer_and_vae() {
        let fixture = TempFixture::new(
            "native-with-vae",
            &[
                "transformer_blocks.0.attn1.to_q.weight",
                "proj_in.weight",
                "vae.encoder.conv_in.weight",
                "vae.decoder.conv_out.weight",
            ],
        );
        let bundle = load(&fixture.path).expect("native combined load");
        assert_eq!(bundle.format, LtxKeyFormat::Native);
        assert!(bundle.has_vae);
    }

    #[test]
    fn diffusers_format_detected() {
        let fixture = TempFixture::new(
            "diffusers",
            &[
                "model.diffusion_model.transformer_blocks.0.attn1.to_q.weight",
                "model.diffusion_model.patchify_proj.weight",
            ],
        );
        let bundle = load(&fixture.path).expect("diffusers load");
        assert_eq!(bundle.format, LtxKeyFormat::Diffusers);
        assert_eq!(bundle.transformer_key_count, 1);
        assert!(!bundle.has_vae);
    }

    #[test]
    fn diffusers_takes_precedence_over_native() {
        let fixture = TempFixture::new(
            "both",
            &[
                "model.diffusion_model.transformer_blocks.0.to_q.weight",
                "transformer_blocks.0.to_q.weight",
            ],
        );
        let bundle = load(&fixture.path).expect("both-format load");
        assert_eq!(bundle.format, LtxKeyFormat::Diffusers);
    }

    #[test]
    fn no_transformer_keys_returns_error() {
        let fixture = TempFixture::new(
            "no-transformer",
            &["vae.encoder.conv_in.weight", "some.other.weight"],
        );
        assert!(matches!(
            load(&fixture.path),
            Err(LoadError::NoTransformerKeys)
        ));
    }

    #[test]
    fn has_prefix_is_exact_segment_match() {
        assert!(!has_prefix("transformer_blocks.0.to_q.weight", "blocks"));
        assert!(!has_prefix(
            "model.diffusion_model.transformer_blocks.0.to_q.weight",
            "model.diffusion_model.blocks"
        ));
        assert!(has_prefix(
            "transformer_blocks.0.to_q.weight",
            "transformer_blocks"
        ));
        assert!(has_prefix("vae", "vae"));
    }
}