use crate::engine::LoadStrategy;
use crate::progress::ProgressReporter;
use mold_core::types::GpuSelection;
use std::cell::Cell;
thread_local! {
    /// GPU ordinal the current thread has been bound to, or `None` when unbound.
    static THREAD_GPU_ORDINAL: Cell<Option<usize>> = const { Cell::new(None) };
}

/// Records `ordinal` as the GPU the calling thread is bound to.
pub fn init_thread_gpu_ordinal(ordinal: usize) {
    THREAD_GPU_ORDINAL.set(Some(ordinal));
}

/// Removes any GPU binding previously recorded for the calling thread.
pub fn clear_thread_gpu_ordinal() {
    THREAD_GPU_ORDINAL.set(None);
}

/// Returns the GPU ordinal recorded for the calling thread, if any.
pub fn thread_gpu_ordinal() -> Option<usize> {
    THREAD_GPU_ORDINAL.get()
}
/// Debug-build guard: panics when `ordinal` differs from the GPU this thread
/// is bound to.
///
/// Does nothing in release builds or when the thread has no bound ordinal.
/// `context` is prefixed to the panic message to identify the call site.
#[inline]
fn debug_assert_ordinal_matches_thread(ordinal: usize, context: &'static str) {
    if !cfg!(debug_assertions) {
        return;
    }
    let Some(expected) = thread_gpu_ordinal() else {
        // Unbound threads may address any ordinal.
        return;
    };
    assert_eq!(
        expected, ordinal,
        "{context}: ordinal {ordinal} does not match this thread's \
         bound GPU {expected} — hardcoded ordinal regression?"
    );
}
/// A GPU reported by `discover_gpus`, with the memory figures sampled at
/// discovery time.
#[derive(Debug, Clone)]
pub struct DiscoveredGpu {
/// Device index; `discover_gpus` uses the CUDA ordinal, or 0 for the single Metal entry.
pub ordinal: usize,
/// Device name; `discover_gpus` falls back to a synthetic label when the driver gives none.
pub name: String,
/// Total device memory in bytes (`discover_gpus` stores 0 when the query fails).
pub total_vram_bytes: u64,
/// Free device memory in bytes (`discover_gpus` stores 0 when the query fails).
pub free_vram_bytes: u64,
}
/// Enumerates the GPUs usable by this build.
///
/// With the `cuda` feature, every CUDA ordinal that can be opened is listed
/// with its name and memory figures; devices that fail to open are logged and
/// skipped. Without it, a single "Apple Metal GPU" entry (ordinal 0) is
/// reported when Metal is available, using the system-memory helpers for its
/// memory figures. Returns an empty vector when nothing is found.
pub fn discover_gpus() -> Vec<DiscoveredGpu> {
let mut gpus = Vec::new();
#[cfg(feature = "cuda")]
{
use candle_core::cuda_backend::cudarc::driver;
if candle_core::utils::cuda_is_available() {
match driver::CudaContext::device_count() {
Ok(count) => {
for ordinal in 0..count as usize {
match driver::CudaContext::new(ordinal) {
Ok(ctx) => {
// Fall back to a synthetic name when the driver cannot report one.
let name = ctx
.name()
.unwrap_or_else(|_| format!("CUDA Device {ordinal}"));
// NOTE(review): mem_get_info takes no device argument; this presumably
// relies on the context opened above being current on this thread — confirm.
// A failed query is recorded as 0 free / 0 total rather than skipping the device.
let (free, total) =
driver::result::mem_get_info().unwrap_or((0, 0));
gpus.push(DiscoveredGpu {
ordinal,
name,
total_vram_bytes: total as u64,
free_vram_bytes: free as u64,
});
}
Err(e) => tracing::warn!("failed to open CUDA device {ordinal}: {e}"),
}
}
}
Err(e) => tracing::warn!("CUDA device count failed: {e}"),
}
}
}
#[cfg(not(feature = "cuda"))]
{
if candle_core::utils::metal_is_available() {
// "total" here is the available-system-memory figure, not physical RAM;
// both helpers return None off macOS, which becomes 0.
let total = available_system_memory_bytes().unwrap_or(0);
let free = free_system_memory_bytes().unwrap_or(0);
gpus.push(DiscoveredGpu {
ordinal: 0,
name: "Apple Metal GPU".to_string(),
total_vram_bytes: total,
free_vram_bytes: free,
});
}
}
gpus
}
/// Returns clones of the entries in `gpus` permitted by `selection`.
///
/// `GpuSelection::All` keeps everything; `GpuSelection::Specific` keeps only
/// GPUs whose ordinal is listed. Input order is preserved.
pub fn filter_gpus(gpus: &[DiscoveredGpu], selection: &GpuSelection) -> Vec<DiscoveredGpu> {
    let is_selected = |gpu: &DiscoveredGpu| match selection {
        GpuSelection::All => true,
        GpuSelection::Specific(ordinals) => ordinals.contains(&gpu.ordinal),
    };
    gpus.iter()
        .filter(|&gpu| is_selected(gpu))
        .cloned()
        .collect()
}
/// Picks the GPU with the most free VRAM, or `None` for an empty slice.
///
/// When several GPUs tie, the last of them in slice order is returned.
pub fn select_best_gpu(gpus: &[DiscoveredGpu]) -> Option<&DiscoveredGpu> {
    gpus.iter()
        .max_by(|left, right| left.free_vram_bytes.cmp(&right.free_vram_bytes))
}
/// Opens the compute device for `ordinal`, reporting the choice through
/// `progress` and `tracing`.
///
/// Selection order: CPU when `MOLD_DEVICE=cpu` (case-insensitive), then CUDA,
/// then Metal, then CPU as a last resort.
///
/// # Errors
/// Propagates the error from opening the CUDA or Metal device.
pub fn create_device(
    ordinal: usize,
    progress: &ProgressReporter,
) -> anyhow::Result<candle_core::Device> {
    use candle_core::{utils, Device};

    let cpu_forced = matches!(
        std::env::var("MOLD_DEVICE"),
        Ok(ref value) if value.eq_ignore_ascii_case("cpu")
    );
    if cpu_forced {
        progress.info("CPU forced via MOLD_DEVICE=cpu");
        tracing::info!("CPU forced via MOLD_DEVICE=cpu");
        return Ok(Device::Cpu);
    }

    // Only checked once a GPU may actually be opened.
    debug_assert_ordinal_matches_thread(ordinal, "create_device");

    let device = if utils::cuda_is_available() {
        progress.info(&format!("Using CUDA device {ordinal}"));
        tracing::info!("Using CUDA device {ordinal}");
        Device::new_cuda(ordinal)?
    } else if utils::metal_is_available() {
        progress.info(&format!("Using Metal device {ordinal}"));
        tracing::info!("Using Metal device {ordinal}");
        Device::new_metal(ordinal)?
    } else {
        progress.info("No GPU detected, using CPU");
        tracing::warn!("No GPU detected, falling back to CPU");
        Device::Cpu
    };
    Ok(device)
}
/// Extra bytes added on top of an encoder's weight size by the
/// `*_vram_threshold` helpers below.
pub const T5_ACTIVATION_HEADROOM: u64 = 2_000_000_000;

/// VRAM threshold for a T5 encoder of `model_size_bytes`: weights plus
/// [`T5_ACTIVATION_HEADROOM`], saturating at `u64::MAX` instead of overflowing.
pub fn t5_vram_threshold(model_size_bytes: u64) -> u64 {
    model_size_bytes.saturating_add(T5_ACTIVATION_HEADROOM)
}

/// Fixed VRAM threshold (bytes) for the T5 encoder.
pub const T5_VRAM_THRESHOLD: u64 = 16_000_000_000;
/// Fixed VRAM threshold (bytes) for the CLIP encoder.
pub const CLIP_VRAM_THRESHOLD: u64 = 800_000_000;
/// Fixed VRAM threshold (bytes) for the CLIP-G encoder.
pub const CLIPG_VRAM_THRESHOLD: u64 = 2_800_000_000;

/// VRAM threshold for a Qwen3 encoder of `model_size_bytes`; uses the same
/// headroom as T5 and saturates at `u64::MAX`.
pub fn qwen3_vram_threshold(model_size_bytes: u64) -> u64 {
    model_size_bytes.saturating_add(T5_ACTIVATION_HEADROOM)
}

/// VRAM threshold for a Qwen2 encoder of `model_size_bytes`; uses the same
/// headroom as T5 and saturates at `u64::MAX`.
pub fn qwen2_vram_threshold(model_size_bytes: u64) -> u64 {
    model_size_bytes.saturating_add(T5_ACTIVATION_HEADROOM)
}
/// Extra bytes that `expand_vram_threshold` adds on top of the model size.
pub const EXPAND_ACTIVATION_HEADROOM: u64 = 2_000_000_000;
/// Model architecture family, used to pick the activation-memory factor in
/// `activation_bytes`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ActivationFamily {
    FluxDit,
    Flux2Dit,
    Sd3Mmdit,
    SdxlUnet,
    QwenImageDit,
    ZImageDit,
    Wuerstchen,
    SmallTransformer,
    LtxVideo,
    Ltx2Video,
}

impl ActivationFamily {
    /// `true` only for [`ActivationFamily::Ltx2Video`].
    pub fn streaming_transformer(self) -> bool {
        self == Self::Ltx2Video
    }

    /// `true` only for [`ActivationFamily::LtxVideo`].
    pub fn is_full_weight_video(self) -> bool {
        self == Self::LtxVideo
    }
}
/// Estimates activation memory in bytes for one forward pass.
///
/// The estimate is `width * height * dtype_bytes * max(batch, 1) * factor`,
/// where `factor` depends on `family`, and is never lower than 256 MB.
pub fn activation_bytes(
    width: u32,
    height: u32,
    batch: u32,
    dtype_bytes: u32,
    family: ActivationFamily,
) -> u64 {
    /// Smallest budget ever returned.
    const ACTIVATION_FLOOR_BYTES: u64 = 256_000_000;

    let multiplier: f64 = match family {
        ActivationFamily::SmallTransformer => 87.0,
        ActivationFamily::FluxDit
        | ActivationFamily::Flux2Dit
        | ActivationFamily::ZImageDit
        | ActivationFamily::LtxVideo
        | ActivationFamily::Ltx2Video => 130.0,
        ActivationFamily::Sd3Mmdit => 156.0,
        ActivationFamily::SdxlUnet | ActivationFamily::QwenImageDit => 173.0,
        ActivationFamily::Wuerstchen => 217.0,
    };

    let pixels = u64::from(width).saturating_mul(u64::from(height));
    // A batch of 0 is treated as 1.
    let per_pixel = u64::from(dtype_bytes).saturating_mul(u64::from(batch.max(1)));
    // The float-to-integer cast saturates, so oversized products clamp to u64::MAX.
    let estimate = (pixels as f64 * per_pixel as f64 * multiplier) as u64;
    std::cmp::max(estimate, ACTIVATION_FLOOR_BYTES)
}
/// Width in bytes used for activation estimates with dtype `dt`.
///
/// Every dtype other than F64, F32, F16 and BF16 is counted as 2 bytes.
pub fn dtype_bytes(dt: candle_core::DType) -> u32 {
    use candle_core::DType;
    match dt {
        DType::F64 => 8,
        DType::F32 => 4,
        DType::F16 | DType::BF16 => 2,
        // All remaining dtypes are budgeted as 2 bytes.
        _ => 2,
    }
}
/// Maps a model family slug to its [`ActivationFamily`].
///
/// Matching is exact and case-sensitive; unknown slugs fall back to
/// [`ActivationFamily::FluxDit`].
pub fn activation_family_for(family_slug: &str) -> ActivationFamily {
    const SLUG_TABLE: &[(&str, ActivationFamily)] = &[
        ("flux", ActivationFamily::FluxDit),
        ("flux2", ActivationFamily::Flux2Dit),
        ("sd3", ActivationFamily::Sd3Mmdit),
        ("sdxl", ActivationFamily::SdxlUnet),
        ("sd15", ActivationFamily::SdxlUnet),
        ("qwen-image", ActivationFamily::QwenImageDit),
        ("qwen-image-edit", ActivationFamily::QwenImageDit),
        ("z-image", ActivationFamily::ZImageDit),
        ("wuerstchen", ActivationFamily::Wuerstchen),
        ("ltx-video", ActivationFamily::LtxVideo),
        ("ltx2", ActivationFamily::Ltx2Video),
        ("ltx-2", ActivationFamily::Ltx2Video),
        ("ltx-2.3", ActivationFamily::Ltx2Video),
    ];

    SLUG_TABLE
        .iter()
        .find(|(slug, _)| *slug == family_slug)
        .map_or(ActivationFamily::FluxDit, |&(_, family)| family)
}
/// VRAM threshold for a prompt-expansion model of `model_size_bytes`:
/// weights plus `EXPAND_ACTIVATION_HEADROOM`.
///
/// Saturates at `u64::MAX` instead of overflowing, so an absurd size yields
/// an unreachable threshold rather than a panic in debug builds or a small
/// wrapped value in release builds.
pub fn expand_vram_threshold(model_size_bytes: u64) -> u64 {
    model_size_bytes.saturating_add(EXPAND_ACTIVATION_HEADROOM)
}
/// Where `select_expand_device_with_preference` decided to place the model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpandPlacement {
/// Run on the GPU with this ordinal.
Gpu(usize),
/// No suitable GPU was found; run on the CPU.
Cpu,
}
/// Chooses a device for the expansion model with no preferred GPU.
///
/// Shorthand for [`select_expand_device_with_preference`] with
/// `preferred_ordinal = None`.
pub fn select_expand_device(
    gpus: &[DiscoveredGpu],
    threshold: u64,
    is_metal: bool,
) -> ExpandPlacement {
    let no_preference: Option<usize> = None;
    select_expand_device_with_preference(gpus, threshold, is_metal, no_preference)
}
/// Chooses a device for the expansion model, trying `preferred_ordinal` first.
///
/// On Metal any listed GPU is acceptable regardless of free memory. Otherwise
/// a GPU qualifies only when its free VRAM is strictly greater than
/// `threshold`. The preferred GPU wins if it qualifies; failing that, the
/// first qualifying GPU in slice order is used; failing that, the CPU.
pub fn select_expand_device_with_preference(
    gpus: &[DiscoveredGpu],
    threshold: u64,
    is_metal: bool,
    preferred_ordinal: Option<usize>,
) -> ExpandPlacement {
    let qualifies = |gpu: &DiscoveredGpu| is_metal || gpu.free_vram_bytes > threshold;

    let preferred = preferred_ordinal.and_then(|wanted| {
        gpus.iter()
            .find(|&gpu| gpu.ordinal == wanted && qualifies(gpu))
    });
    let chosen = preferred.or_else(|| gpus.iter().find(|&gpu| qualifies(gpu)));

    match chosen {
        Some(gpu) => ExpandPlacement::Gpu(gpu.ordinal),
        None => ExpandPlacement::Cpu,
    }
}
pub const LTX2_GEMMA_VRAM_THRESHOLD: u64 = 24_000_000_000;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LtxGemmaPlacement {
Gpu(usize),
Cpu,
}
impl LtxGemmaPlacement {
pub fn into_device(self) -> candle_core::Device {
match self {
LtxGemmaPlacement::Gpu(ordinal) => match candle_core::Device::new_cuda(ordinal) {
Ok(d) => d,
Err(err) => {
tracing::warn!(
ordinal,
error = %err,
"failed to open CUDA device for LTX-2 Gemma encoder, falling back to CPU"
);
candle_core::Device::Cpu
}
},
LtxGemmaPlacement::Cpu => candle_core::Device::Cpu,
}
}
}
/// Chooses where the LTX-2 Gemma encoder should run.
///
/// Prefers the GPU with `active_ordinal` when its free VRAM is strictly
/// greater than `threshold`; otherwise takes the first other GPU in slice
/// order that clears the threshold; otherwise the CPU.
pub fn select_ltx2_gemma_device(
    gpus: &[DiscoveredGpu],
    active_ordinal: usize,
    threshold: u64,
) -> LtxGemmaPlacement {
    let has_room = |gpu: &DiscoveredGpu| gpu.free_vram_bytes > threshold;

    let on_active = gpus
        .iter()
        .find(|&gpu| gpu.ordinal == active_ordinal && has_room(gpu));
    let chosen = on_active.or_else(|| {
        gpus.iter()
            .find(|&gpu| gpu.ordinal != active_ordinal && has_room(gpu))
    });

    chosen.map_or(LtxGemmaPlacement::Cpu, |gpu| {
        LtxGemmaPlacement::Gpu(gpu.ordinal)
    })
}
/// Reads the environment for an explicit LTX-2 Gemma placement.
///
/// `MOLD_LTX2_GEMMA_DEVICE` (trimmed, case-insensitive) is checked first:
/// `cpu` forces the CPU, `gpu` forces `gpu_ordinal`, `auto` and unrecognised
/// values give `None` (the latter with a warning). Only when that variable is
/// unset or blank is the deprecated `MOLD_LTX2_DEBUG_FORCE_CPU_PROMPT_ENCODER`
/// honoured; its presence forces the CPU.
pub fn resolve_ltx2_gemma_device_override(gpu_ordinal: usize) -> Option<LtxGemmaPlacement> {
    // An unset or unreadable variable behaves like a blank one.
    let raw = std::env::var("MOLD_LTX2_GEMMA_DEVICE").unwrap_or_default();
    let requested = raw.trim();

    if !requested.is_empty() {
        return match requested.to_ascii_lowercase().as_str() {
            "cpu" => Some(LtxGemmaPlacement::Cpu),
            "gpu" => Some(LtxGemmaPlacement::Gpu(gpu_ordinal)),
            "auto" => None,
            _ => {
                tracing::warn!(
                    value = %requested,
                    "unrecognised MOLD_LTX2_GEMMA_DEVICE value; expected cpu/gpu/auto",
                );
                None
            }
        };
    }

    if std::env::var_os("MOLD_LTX2_DEBUG_FORCE_CPU_PROMPT_ENCODER").is_none() {
        return None;
    }
    warn_once_legacy_force_cpu_prompt_encoder();
    Some(LtxGemmaPlacement::Cpu)
}
/// Logs the deprecation warning for `MOLD_LTX2_DEBUG_FORCE_CPU_PROMPT_ENCODER`
/// at most once per process.
fn warn_once_legacy_force_cpu_prompt_encoder() {
    use std::sync::OnceLock;

    static WARNED: OnceLock<()> = OnceLock::new();
    // Only the first caller manages to fill the cell, so only it logs.
    if WARNED.set(()).is_ok() {
        tracing::warn!(
            "MOLD_LTX2_DEBUG_FORCE_CPU_PROMPT_ENCODER is deprecated; \
             use MOLD_LTX2_GEMMA_DEVICE=cpu instead",
        );
    }
}
/// Resolves the LTX-2 Gemma encoder placement for a job running on `gpu_ordinal`.
///
/// An environment override wins; otherwise the GPUs are discovered and the
/// choice is made against [`LTX2_GEMMA_VRAM_THRESHOLD`].
pub fn resolve_ltx2_gemma_placement(gpu_ordinal: usize) -> LtxGemmaPlacement {
    resolve_ltx2_gemma_device_override(gpu_ordinal).unwrap_or_else(|| {
        // Discovery only runs when no override is in effect.
        let discovered = discover_gpus();
        select_ltx2_gemma_device(&discovered, gpu_ordinal, LTX2_GEMMA_VRAM_THRESHOLD)
    })
}
/// Fixed VRAM threshold in bytes for the FP16 Qwen3 encoder.
/// NOTE(review): outside the tests, no use of this constant is visible in this part of the file — confirm callers.
pub const QWEN3_FP16_VRAM_THRESHOLD: u64 = 10_200_000_000;
/// Slack in bytes that `estimate_peak_memory` adds on top of the summed file sizes.
const MEMORY_BUDGET_HEADROOM: u64 = 2_000_000_000;
/// Turns an optional device request into a concrete candle device.
///
/// `None` and `DeviceRef::Auto` defer to the `auto` callback, `DeviceRef::Cpu`
/// yields the CPU, and `DeviceRef::Gpu` opens that ordinal on the backend this
/// build was compiled with.
///
/// # Errors
/// Propagates failures from `auto` or from opening the requested GPU.
pub fn resolve_device<F>(
    req: Option<mold_core::types::DeviceRef>,
    auto: F,
) -> anyhow::Result<candle_core::Device>
where
    F: FnOnce() -> anyhow::Result<candle_core::Device>,
{
    use mold_core::types::DeviceRef;

    let Some(requested) = req else {
        return auto();
    };
    match requested {
        DeviceRef::Auto => auto(),
        DeviceRef::Cpu => Ok(candle_core::Device::Cpu),
        DeviceRef::Gpu { ordinal } => resolve_gpu_ordinal(ordinal),
    }
}
/// CUDA build: opens CUDA device `ordinal`, wrapping any failure in an
/// `anyhow` error that names the ordinal.
#[cfg(feature = "cuda")]
fn resolve_gpu_ordinal(ordinal: usize) -> anyhow::Result<candle_core::Device> {
// Debug-only check that the request matches this thread's bound GPU.
debug_assert_ordinal_matches_thread(ordinal, "resolve_device");
candle_core::Device::new_cuda(ordinal)
.map_err(|e| anyhow::anyhow!("failed to open CUDA device {ordinal}: {e}"))
}
/// Metal build (no CUDA): opens Metal device `ordinal`, wrapping any failure
/// in an `anyhow` error that names the ordinal.
#[cfg(all(not(feature = "cuda"), feature = "metal"))]
fn resolve_gpu_ordinal(ordinal: usize) -> anyhow::Result<candle_core::Device> {
// Debug-only check that the request matches this thread's bound GPU.
debug_assert_ordinal_matches_thread(ordinal, "resolve_device");
candle_core::Device::new_metal(ordinal)
.map_err(|e| anyhow::anyhow!("failed to open Metal device {ordinal}: {e}"))
}
/// CPU-only build: an explicit GPU request always fails.
#[cfg(all(not(feature = "cuda"), not(feature = "metal")))]
fn resolve_gpu_ordinal(ordinal: usize) -> anyhow::Result<candle_core::Device> {
Err(anyhow::anyhow!(
"GPU ordinal {ordinal} requested but this build has neither CUDA nor Metal enabled"
))
}
/// Computes the device reference that applies to one component.
///
/// * No placement at all gives `DeviceRef::Auto`.
/// * A placement without an `advanced` section gives its `text_encoders` value.
/// * With an `advanced` section, `advanced_override` decides; when it returns
///   `None`, the result is `text_encoders` if `fallback_is_component_auto` is
///   set and `DeviceRef::Auto` otherwise.
pub fn effective_device_ref(
    placement: Option<&mold_core::types::DevicePlacement>,
    advanced_override: impl FnOnce(
        &mold_core::types::AdvancedPlacement,
    ) -> Option<mold_core::types::DeviceRef>,
    fallback_is_component_auto: bool,
) -> mold_core::types::DeviceRef {
    use mold_core::types::DeviceRef;

    let Some(placement) = placement else {
        return DeviceRef::Auto;
    };
    let Some(advanced) = placement.advanced.as_ref() else {
        return placement.text_encoders;
    };
    match advanced_override(advanced) {
        Some(explicit) => explicit,
        None if fallback_is_component_auto => placement.text_encoders,
        None => DeviceRef::Auto,
    }
}
/// Memory figures read from the Mach host statistics, in bytes.
#[cfg(target_os = "macos")]
struct MacOSMemInfo {
/// First statistics counter multiplied by the page size.
free: u64,
/// Third statistics counter multiplied by the page size.
inactive: u64,
}
/// Queries `host_statistics64` and `host_page_size` and converts the free and
/// inactive page counts to bytes.
///
/// Returns `None` when either Mach call does not report `KERN_SUCCESS`.
#[cfg(target_os = "macos")]
fn macos_vm_stats() -> Option<MacOSMemInfo> {
// Local stand-ins for the Mach typedefs, so no binding crate is needed.
type MachPort = u32;
type KernReturn = i32;
type HostFlavor = i32;
type MachMsgType = u32;
const HOST_VM_INFO64: HostFlavor = 4;
// Size of the statistics buffer, counted in 32-bit words.
const HOST_VM_INFO64_COUNT: MachMsgType = 38;
const KERN_SUCCESS: KernReturn = 0;
extern "C" {
fn mach_host_self() -> MachPort;
fn host_statistics64(
host: MachPort,
flavor: HostFlavor,
info: *mut i32,
count: *mut MachMsgType,
) -> KernReturn;
fn host_page_size(host: MachPort, page_size: *mut usize) -> KernReturn;
}
// SAFETY: `buf` holds exactly HOST_VM_INFO64_COUNT i32 words and `count` is
// initialised to that same capacity; `page_size` is a valid, writable usize.
// NOTE(review): assumes these declarations match the system's Mach
// signatures — confirm against the SDK headers.
unsafe {
let mut buf = [0i32; HOST_VM_INFO64_COUNT as usize];
let mut count = HOST_VM_INFO64_COUNT;
let ret = host_statistics64(
mach_host_self(),
HOST_VM_INFO64,
buf.as_mut_ptr(),
&mut count,
);
if ret != KERN_SUCCESS {
return None;
}
let mut page_size: usize = 0;
let ret = host_page_size(mach_host_self(), &mut page_size);
if ret != KERN_SUCCESS {
return None;
}
let page_size = page_size as u64;
// Words 0 and 2 are reinterpreted as unsigned page counts.
// NOTE(review): presumably the free and inactive counters of the kernel's
// vm_statistics64 layout — confirm against <mach/vm_statistics.h>.
Some(MacOSMemInfo {
free: buf[0] as u32 as u64 * page_size,
inactive: buf[2] as u32 as u64 * page_size,
})
}
}
/// Free system memory in bytes (macOS: the free page count times the page size).
#[cfg(target_os = "macos")]
pub fn free_system_memory_bytes() -> Option<u64> {
    let stats = macos_vm_stats()?;
    Some(stats.free)
}

/// Available system memory in bytes (macOS: free plus inactive pages).
#[cfg(target_os = "macos")]
pub fn available_system_memory_bytes() -> Option<u64> {
    let MacOSMemInfo { free, inactive } = macos_vm_stats()?;
    Some(free + inactive)
}

/// Free system memory in bytes; not implemented off macOS, always `None`.
#[cfg(not(target_os = "macos"))]
pub fn free_system_memory_bytes() -> Option<u64> {
    Option::None
}

/// Available system memory in bytes; not implemented off macOS, always `None`.
#[cfg(not(target_os = "macos"))]
pub fn available_system_memory_bytes() -> Option<u64> {
    Option::None
}
/// Returns `true` only when `MOLD_KEEP_TE_RAM` is set to exactly `1`.
pub fn keep_te_in_ram() -> bool {
    matches!(std::env::var("MOLD_KEEP_TE_RAM").as_deref(), Ok("1"))
}
/// CUDA build: resets the primary context of device `ordinal` and logs the
/// outcome. Failures are logged as warnings and never returned to the caller.
#[cfg(feature = "cuda")]
pub fn reclaim_gpu_memory(ordinal: usize) {
use candle_core::cuda_backend::cudarc::driver::{result, sys};
debug_assert_ordinal_matches_thread(ordinal, "reclaim_gpu_memory");
// Best effort: a synchronize failure is ignored and the reset still runs.
let _ = result::ctx::synchronize();
let cu_device = match result::device::get(ordinal as i32) {
Ok(d) => d,
Err(e) => {
tracing::warn!("reclaim_gpu_memory: failed to get device {ordinal}: {e}");
return;
}
};
// SAFETY: `cu_device` is a handle the driver just returned for `ordinal`.
// NOTE(review): a primary-context reset presumably invalidates every
// allocation and handle created in that context; callers must have dropped
// all tensors on this device first — confirm.
let result = unsafe { sys::cuDevicePrimaryCtxReset_v2(cu_device) };
if result != sys::CUresult::CUDA_SUCCESS {
tracing::warn!(
"reclaim_gpu_memory: cuDevicePrimaryCtxReset for device {ordinal} returned {result:?}"
);
} else {
tracing::info!("CUDA primary context reset for device {ordinal}, GPU memory reclaimed");
}
}
/// Non-CUDA build: nothing to reclaim, so this is a no-op.
#[cfg(not(feature = "cuda"))]
pub fn reclaim_gpu_memory(_ordinal: usize) {}
/// CUDA build: best-effort context synchronize; any error is discarded.
///
/// `_ordinal` is not used: the call takes no device argument.
/// NOTE(review): presumably this synchronizes whichever context is current
/// on the calling thread — confirm.
#[cfg(feature = "cuda")]
pub fn try_synchronize_device(_ordinal: usize) {
use candle_core::cuda_backend::cudarc::driver::result;
let _ = result::ctx::synchronize();
}
/// Non-CUDA build: no-op.
#[cfg(not(feature = "cuda"))]
pub fn try_synchronize_device(_ordinal: usize) {}
/// CUDA build: free device memory in bytes for `ordinal`, or `None` when the
/// device cannot be opened or the memory query fails.
#[cfg(feature = "cuda")]
pub fn free_vram_bytes(ordinal: usize) -> Option<u64> {
// The context is opened only to test that it succeeds; the handle is not kept.
// NOTE(review): mem_get_info takes no device argument, so this presumably
// depends on `CudaContext::new` having made that device current — confirm.
if candle_core::cuda_backend::cudarc::driver::CudaContext::new(ordinal).is_ok() {
candle_core::cuda_backend::cudarc::driver::result::mem_get_info()
.ok()
.map(|(free, _total)| free as u64)
} else {
None
}
}
/// Non-CUDA build: reports system memory instead, preferring the "available"
/// figure and falling back to the "free" figure. `_ordinal` is ignored.
#[cfg(not(feature = "cuda"))]
pub fn free_vram_bytes(_ordinal: usize) -> Option<u64> {
available_system_memory_bytes().or_else(free_system_memory_bytes)
}
/// Bytes of VRAM to leave untouched when budgeting.
///
/// `MOLD_RESERVE_VRAM_MB` overrides the default when it parses as an unsigned
/// integer number of megabytes (1 MB = 1_000_000 bytes). Surrounding
/// whitespace is ignored, and the product saturates at `u64::MAX`. Any other
/// value is ignored and the per-OS default applies: 600 MB on Windows, 0 on
/// macOS, 400 MB everywhere else.
pub fn reserved_vram_bytes() -> u64 {
    let override_mb = std::env::var("MOLD_RESERVE_VRAM_MB")
        .ok()
        // Trim first: a value like "512 " would otherwise fail to parse and be
        // silently replaced by the default.
        .and_then(|raw| raw.trim().parse::<u64>().ok());
    if let Some(mb) = override_mb {
        return mb.saturating_mul(1_000_000);
    }

    if cfg!(target_os = "windows") {
        600_000_000
    } else if cfg!(target_os = "macos") {
        0
    } else {
        400_000_000
    }
}
/// Free VRAM on `ordinal` minus the configured reserve, clamped at zero.
///
/// Returns `None` when the free-memory query itself is unavailable.
pub fn usable_free_vram_bytes(ordinal: usize) -> Option<u64> {
    let reserve = reserved_vram_bytes();
    let raw_free = free_vram_bytes(ordinal)?;
    Some(usable_free_vram_from_raw(raw_free, reserve))
}
/// Subtracts `reserve` from `free`, yielding 0 when the reserve exceeds what
/// is free.
fn usable_free_vram_from_raw(free: u64, reserve: u64) -> u64 {
    u64::saturating_sub(free, reserve)
}
/// CUDA build: bytes of device memory in use on `ordinal` (total minus
/// free), or 0 when the device cannot be opened or the query fails.
#[cfg(feature = "cuda")]
pub fn vram_in_use_bytes(ordinal: usize) -> u64 {
// The context is opened only to test that it succeeds; the handle is not kept.
if candle_core::cuda_backend::cudarc::driver::CudaContext::new(ordinal).is_ok() {
candle_core::cuda_backend::cudarc::driver::result::mem_get_info()
.ok()
// NOTE(review): unchecked subtraction; assumes the driver never reports
// free > total — confirm.
.map(|(free, total)| total as u64 - free as u64)
.unwrap_or(0)
} else {
0
}
}
/// Non-CUDA build: usage is not tracked, always 0.
#[cfg(not(feature = "cuda"))]
pub fn vram_in_use_bytes(_ordinal: usize) -> u64 {
0
}
/// CUDA build: total device memory in bytes for `ordinal`, or `None` when
/// the device cannot be opened or the memory query fails.
#[cfg(feature = "cuda")]
pub fn total_vram_bytes(ordinal: usize) -> Option<u64> {
// The context is opened only to test that it succeeds; the handle is not kept.
if candle_core::cuda_backend::cudarc::driver::CudaContext::new(ordinal).is_ok() {
candle_core::cuda_backend::cudarc::driver::result::mem_get_info()
.ok()
.map(|(_free, total)| total as u64)
} else {
None
}
}
/// Non-CUDA build: total VRAM is not reported, always `None`.
#[cfg(not(feature = "cuda"))]
pub fn total_vram_bytes(_ordinal: usize) -> Option<u64> {
None
}
/// Growth in VRAM usage on `ordinal` since `baseline` was sampled.
///
/// Clamped at 0 when usage has dropped below the baseline.
pub fn vram_load_delta(ordinal: usize, baseline: u64) -> u64 {
    let in_use_now = vram_in_use_bytes(ordinal);
    in_use_now.saturating_sub(baseline)
}
/// `true` when `device` is a CUDA or Metal device.
pub(crate) fn is_gpu(device: &candle_core::Device) -> bool {
    let on_cuda = device.is_cuda();
    let on_metal = device.is_metal();
    on_cuda || on_metal
}
#[allow(dead_code)]
pub(crate) fn gpu_compute_dtype(device: &candle_core::Device) -> candle_core::DType {
if is_gpu(device) {
candle_core::DType::BF16
} else {
candle_core::DType::F32
}
}
pub(crate) fn gpu_dtype(device: &candle_core::Device) -> candle_core::DType {
if device.is_cuda() {
candle_core::DType::BF16
} else {
candle_core::DType::F32
}
}
/// Resolves the VAE dtype, honouring the `MOLD_VAE_DTYPE` override.
///
/// The value is trimmed and matched case-insensitively:
/// * unset, empty or `auto` gives `default_dtype`
/// * `bf16` gives BF16
/// * `fp16` or `f16` gives F16
/// * `fp32` or `f32` gives F32
///
/// Anything else logs a warning and falls back to `default_dtype`.
pub(crate) fn resolve_vae_dtype(default_dtype: candle_core::DType) -> candle_core::DType {
    use candle_core::DType;

    let Ok(raw) = std::env::var("MOLD_VAE_DTYPE") else {
        return default_dtype;
    };
    let requested = raw.trim();
    // Lowercasing makes every spelling (`Bf16`, `AUTO`, ...) behave the same,
    // instead of accepting only all-lower and all-upper forms.
    match requested.to_ascii_lowercase().as_str() {
        "" | "auto" => default_dtype,
        "bf16" => DType::BF16,
        "fp16" | "f16" => DType::F16,
        "fp32" | "f32" => DType::F32,
        _ => {
            tracing::warn!(
                value = requested,
                "MOLD_VAE_DTYPE has unrecognised value; expected one of auto/bf16/fp16/fp32 — falling back to default"
            );
            default_dtype
        }
    }
}
/// Formats `bytes` as decimal gigabytes with one fractional digit, e.g. `"14.6 GB"`.
pub(crate) fn fmt_gb(bytes: u64) -> String {
    const BYTES_PER_GB: f64 = 1_000_000_000.0;
    let gigabytes = bytes as f64 / BYTES_PER_GB;
    format!("{gigabytes:.1} GB")
}
/// Decides whether a component should be placed on the GPU.
///
/// Metal always gets the GPU. CUDA gets it only when `free_vram` is strictly
/// greater than `threshold`. With neither backend the answer is `false`.
pub(crate) fn should_use_gpu(
    is_cuda: bool,
    is_metal: bool,
    free_vram: u64,
    threshold: u64,
) -> bool {
    is_metal || (is_cuda && free_vram > threshold)
}
/// Offloading is never chosen when less than this much VRAM (bytes) is free.
pub(crate) const MIN_OFFLOAD_VRAM: u64 = 4_000_000_000;
/// Slack in bytes added to the fully-resident requirement.
pub(crate) const FULL_RESIDENT_RUNTIME_HEADROOM: u64 = 2_000_000_000;

/// Decides whether the transformer must be offloaded instead of kept fully
/// resident.
///
/// Offloading is chosen when transformer weights, activations and
/// [`FULL_RESIDENT_RUNTIME_HEADROOM`] together exceed `free_vram`, provided
/// `free_vram` is at least [`MIN_OFFLOAD_VRAM`]. A `free_vram` of 0 never
/// offloads.
pub(crate) fn should_offload(transformer_size: u64, free_vram: u64, activation_bytes: u64) -> bool {
    if free_vram == 0 || free_vram < MIN_OFFLOAD_VRAM {
        return false;
    }
    let resident_requirement = transformer_size
        .saturating_add(activation_bytes)
        .saturating_add(FULL_RESIDENT_RUNTIME_HEADROOM);
    resident_requirement > free_vram
}
/// Checks whether a component needing `threshold` bytes fits in `free_vram`.
///
/// On Metal a `free_vram` of 0 always fits; otherwise the comparison is
/// strict (`free_vram > threshold`). With neither backend nothing fits.
pub(crate) fn fits_in_memory(
    is_cuda: bool,
    is_metal: bool,
    free_vram: u64,
    threshold: u64,
) -> bool {
    match (is_metal, is_cuda) {
        (true, _) if free_vram == 0 => true,
        (true, _) | (false, true) => free_vram > threshold,
        (false, false) => false,
    }
}
/// Estimates peak memory in bytes for loading the model described by `paths`.
///
/// Sizes are taken from file metadata; unreadable files count as 0. Files
/// shared between components are counted once. `Eager` sums transformer, VAE
/// and all encoders; `Sequential` takes the larger of "all encoders" and
/// "transformer + VAE". Both add `MEMORY_BUDGET_HEADROOM`.
pub fn estimate_peak_memory(paths: &mold_core::ModelPaths, strategy: LoadStrategy) -> u64 {
// Size on disk, or 0 when the metadata lookup fails.
let file_size = |p: &std::path::Path| std::fs::metadata(p).map(|m| m.len()).unwrap_or(0);
// Two paths name the same file if they are equal as written or once both canonicalise successfully.
let same_file = |a: &std::path::Path, b: &std::path::Path| -> bool {
a == b
|| std::fs::canonicalize(a)
.ok()
.zip(std::fs::canonicalize(b).ok())
.is_some_and(|(a, b)| a == b)
};
let path_matches_any = |path: &std::path::Path, paths: &[std::path::PathBuf]| -> bool {
paths.iter().any(|candidate| same_file(path, candidate))
};
// Shards, when present, take the place of the single transformer file.
let transformer_size = if !paths.transformer_shards.is_empty() {
paths.transformer_shards.iter().map(|p| file_size(p)).sum()
} else {
file_size(&paths.transformer)
};
// A VAE that lives in a transformer file is already counted above.
let vae_is_transformer_file = if paths.transformer_shards.is_empty() {
same_file(&paths.transformer, &paths.vae)
} else {
paths
.transformer_shards
.iter()
.any(|shard| same_file(shard, &paths.vae))
};
let vae_is_separate_file = !vae_is_transformer_file;
let vae_size = if vae_is_separate_file {
file_size(&paths.vae)
} else {
0
};
// Files already accounted for by the transformer/VAE, used to dedupe encoders.
let mut base_component_paths: Vec<std::path::PathBuf> = paths.transformer_shards.to_vec();
if base_component_paths.is_empty() {
base_component_paths.push(paths.transformer.clone());
}
if vae_is_separate_file {
base_component_paths.push(paths.vae.clone());
}
let mut counted_encoder_paths: Vec<std::path::PathBuf> = Vec::new();
// Stateful: returns 0 for a file seen before, otherwise records it and returns its size.
// The order of the calls below therefore decides which duplicate carries the size.
let mut encoder_size = |path: &std::path::Path| -> u64 {
if path_matches_any(path, &base_component_paths)
|| path_matches_any(path, &counted_encoder_paths)
{
return 0;
}
counted_encoder_paths.push(path.to_path_buf());
file_size(path)
};
let t5_size = paths
.t5_encoder
.as_ref()
.map(|p| encoder_size(p))
.unwrap_or(0);
let clip_size = paths
.clip_encoder
.as_ref()
.map(|p| encoder_size(p))
.unwrap_or(0);
let clip2_size = paths
.clip_encoder_2
.as_ref()
.map(|p| encoder_size(p))
.unwrap_or(0);
let text_encoder_size: u64 = paths
.text_encoder_files
.iter()
.map(|p| encoder_size(p))
.sum();
let encoder_total = t5_size + clip_size + clip2_size + text_encoder_size;
match strategy {
// Everything is resident at once.
LoadStrategy::Eager => transformer_size + vae_size + encoder_total + MEMORY_BUDGET_HEADROOM,
// Peak is the larger of the two groups, which are not summed together.
LoadStrategy::Sequential => {
let peak_encoder = encoder_total;
let peak_inference = transformer_size + vae_size;
std::cmp::max(peak_encoder, peak_inference) + MEMORY_BUDGET_HEADROOM
}
}
}
/// Returns a warning message when the estimated peak memory exceeds 80% of
/// available system memory.
///
/// Returns `None` when the model fits or when available memory is unknown.
pub fn check_memory_budget(
    paths: &mold_core::ModelPaths,
    strategy: LoadStrategy,
) -> Option<String> {
    let available = available_system_memory_bytes()?;
    let peak = estimate_peak_memory(paths, strategy);
    let soft_limit = available * 80 / 100;

    (peak > soft_limit).then(|| {
        format!(
            "Model needs ~{} but only ~{} available. \
             Consider a smaller quantized variant or close other applications.",
            fmt_gb(peak),
            fmt_gb(available),
        )
    })
}
/// Checks before loading `component_name` that system memory can hold its
/// weights plus activations.
///
/// The check is skipped (returns `Ok`) when `MOLD_EAGER=1` or when available
/// memory is unknown or reported as 0.
///
/// # Errors
/// Returns the error from `preflight_check_budget` when the budget is exceeded.
pub(crate) fn preflight_memory_check(
    component_name: &str,
    size_bytes: u64,
    activation_bytes: u64,
) -> anyhow::Result<()> {
    let eager = std::env::var("MOLD_EAGER").is_ok_and(|v| v == "1");
    if eager {
        return Ok(());
    }

    let Some(available) = available_system_memory_bytes().filter(|&bytes| bytes > 0) else {
        // No usable measurement; nothing to check against.
        return Ok(());
    };
    let free = free_system_memory_bytes();
    let required = size_bytes.saturating_add(activation_bytes);
    preflight_check_budget(component_name, required, available, free)
}
/// Applies the memory budget rules to already-measured figures.
///
/// Fails when `size_bytes` exceeds 90% of `available`. Otherwise succeeds,
/// logging a warning when `free` is known and `size_bytes` is more than twice
/// that amount.
///
/// # Errors
/// Returns a "Not enough memory" error naming the component and both sizes.
fn preflight_check_budget(
    component_name: &str,
    size_bytes: u64,
    available: u64,
    free: Option<u64>,
) -> anyhow::Result<()> {
    let hard_limit = available * 90 / 100;
    anyhow::ensure!(
        size_bytes <= hard_limit,
        "Not enough memory to load {} ({} needed, {} available).\n\
         Close other applications or use a smaller quantized model.",
        component_name,
        fmt_gb(size_bytes),
        fmt_gb(available),
    );

    match free {
        Some(f) if size_bytes > f * 2 => {
            tracing::warn!(
                "{} ({}) exceeds free memory ({}), will reclaim inactive pages",
                component_name,
                fmt_gb(size_bytes),
                fmt_gb(f),
            );
        }
        _ => {}
    }
    Ok(())
}
/// Builds a short human-readable memory summary for status output.
///
/// On CUDA builds this reports free VRAM for the GPU the calling thread is
/// bound to (see `thread_gpu_ordinal`), falling back to ordinal 0 on unbound
/// threads. On macOS it reports free and available system memory. Returns
/// `None` when neither source yields a figure.
pub fn memory_status_string() -> Option<String> {
    #[cfg(feature = "cuda")]
    {
        // Query the GPU this thread actually works on, not a hardcoded device
        // 0, so multi-GPU workers report their own card.
        let ordinal = thread_gpu_ordinal().unwrap_or(0);
        if let Some(free) = free_vram_bytes(ordinal) {
            return Some(format!("VRAM: {} free", fmt_gb(free)));
        }
    }
    #[cfg(target_os = "macos")]
    {
        if let Some(stats) = macos_vm_stats() {
            let available = stats.free + stats.inactive;
            return Some(format!(
                "Memory: {} free, {} available",
                fmt_gb(stats.free),
                fmt_gb(available),
            ));
        }
    }
    None
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn fmt_gb_zero() {
assert_eq!(fmt_gb(0), "0.0 GB");
}
#[test]
fn fmt_gb_one_gb() {
assert_eq!(fmt_gb(1_000_000_000), "1.0 GB");
}
#[test]
fn fmt_gb_fractional() {
assert_eq!(fmt_gb(14_600_000_000), "14.6 GB");
}
#[test]
fn fmt_gb_small() {
assert_eq!(fmt_gb(800_000_000), "0.8 GB");
}
#[cfg(target_os = "macos")]
#[test]
fn free_system_memory_returns_positive() {
let mem = free_system_memory_bytes();
assert!(mem.is_some());
assert!(mem.unwrap() > 0, "free system memory should be positive");
}
#[cfg(target_os = "macos")]
#[test]
fn available_includes_inactive() {
let free = free_system_memory_bytes().unwrap();
let available = available_system_memory_bytes().unwrap();
assert!(
available >= free,
"available (free+inactive) should be >= free alone"
);
}
#[test]
fn free_vram_returns_some_on_macos_or_none_on_other() {
let _result = free_vram_bytes(0);
#[cfg(target_os = "macos")]
assert!(_result.is_some(), "macOS should return system memory info");
#[cfg(not(any(target_os = "macos", feature = "cuda")))]
assert_eq!(_result, None);
}
// On macOS `free_vram_bytes` should track the "available" figure, not only "free".
#[cfg(target_os = "macos")]
#[test]
fn free_vram_returns_available_not_just_free_on_macos() {
    // The samples are taken at slightly different times, so allow some drift.
    const MAX_DRIFT_BYTES: u64 = 256 * 4096;

    let vram = free_vram_bytes(0).unwrap();
    let available = available_system_memory_bytes().unwrap();
    let free = free_system_memory_bytes().unwrap();

    assert!(
        vram >= free,
        "free_vram_bytes ({vram}) should be >= free_system_memory ({free})"
    );
    let drift = vram.abs_diff(available);
    assert!(
        drift < MAX_DRIFT_BYTES,
        "free_vram_bytes ({vram}) should approximately equal available_system_memory ({available})"
    );
}
#[test]
fn metal_always_uses_gpu() {
assert!(should_use_gpu(false, true, 0, T5_VRAM_THRESHOLD));
assert!(should_use_gpu(false, true, 1_000, T5_VRAM_THRESHOLD));
assert!(should_use_gpu(
false,
true,
100_000_000_000,
T5_VRAM_THRESHOLD
));
}
#[test]
fn metal_fits_when_enough_free() {
assert!(fits_in_memory(
false,
true,
20_000_000_000,
T5_VRAM_THRESHOLD
));
}
#[test]
fn metal_does_not_fit_when_free_low() {
assert!(!fits_in_memory(
false,
true,
2_000_000_000,
T5_VRAM_THRESHOLD
));
}
#[test]
fn metal_fits_fallback_when_no_memory_info() {
assert!(fits_in_memory(false, true, 0, T5_VRAM_THRESHOLD));
}
#[test]
fn t5_on_gpu_when_plenty_of_vram() {
assert!(should_use_gpu(
true,
false,
16_700_000_000,
T5_VRAM_THRESHOLD
));
}
#[test]
fn t5_on_cpu_when_q6_on_24gb() {
assert!(!should_use_gpu(
true,
false,
14_600_000_000,
T5_VRAM_THRESHOLD
));
}
#[test]
fn t5_on_cpu_when_q8_on_24gb() {
assert!(!should_use_gpu(
true,
false,
11_700_000_000,
T5_VRAM_THRESHOLD
));
}
#[test]
fn t5_on_cpu_when_bf16_fills_vram() {
assert!(!should_use_gpu(true, false, 700_000_000, T5_VRAM_THRESHOLD));
}
#[test]
fn t5_on_cpu_when_exactly_at_threshold() {
assert!(!should_use_gpu(
true,
false,
T5_VRAM_THRESHOLD,
T5_VRAM_THRESHOLD
));
}
#[test]
fn t5_on_cpu_when_no_gpu() {
assert!(!should_use_gpu(
false,
false,
100_000_000_000,
T5_VRAM_THRESHOLD
));
}
#[test]
fn t5_on_gpu_on_48gb_card() {
assert!(should_use_gpu(
true,
false,
35_700_000_000,
T5_VRAM_THRESHOLD
));
}
#[test]
fn clip_on_gpu_when_vram_available() {
assert!(should_use_gpu(
true,
false,
7_500_000_000,
CLIP_VRAM_THRESHOLD
));
}
#[test]
fn clip_on_gpu_with_minimal_vram() {
assert!(should_use_gpu(
true,
false,
900_000_000,
CLIP_VRAM_THRESHOLD
));
}
#[test]
fn clip_on_cpu_when_vram_tight() {
assert!(!should_use_gpu(
true,
false,
500_000_000,
CLIP_VRAM_THRESHOLD
));
}
#[test]
fn t5_threshold_accounts_for_headroom() {
let threshold = std::hint::black_box(T5_VRAM_THRESHOLD);
assert!(threshold > 9_200_000_000);
assert!(threshold < 25_000_000_000);
}
#[test]
fn clip_threshold_accounts_for_headroom() {
let threshold = std::hint::black_box(CLIP_VRAM_THRESHOLD);
assert!(threshold > 246_000_000);
assert!(threshold < 2_000_000_000);
}
#[test]
fn t5_threshold_for_fp16() {
let threshold = t5_vram_threshold(9_200_000_000);
assert!(threshold > 9_200_000_000);
assert!(threshold <= 16_000_000_000);
}
#[test]
fn t5_threshold_for_q8() {
let threshold = t5_vram_threshold(5_060_000_000);
assert_eq!(threshold, 7_060_000_000);
assert!(should_use_gpu(true, false, 17_000_000_000, threshold));
assert!(should_use_gpu(true, false, 12_000_000_000, threshold));
}
#[test]
fn t5_threshold_for_q5() {
let threshold = t5_vram_threshold(3_390_000_000);
assert_eq!(threshold, 5_390_000_000);
assert!(should_use_gpu(true, false, 12_000_000_000, threshold));
}
#[test]
fn t5_threshold_for_q3() {
let threshold = t5_vram_threshold(2_100_000_000);
assert_eq!(threshold, 4_100_000_000);
}
#[test]
fn qwen3_fp16_threshold_with_drop_and_reload() {
assert_eq!(QWEN3_FP16_VRAM_THRESHOLD, 10_200_000_000);
assert!(should_use_gpu(
true,
false,
17_000_000_000,
QWEN3_FP16_VRAM_THRESHOLD
));
assert!(should_use_gpu(
true,
false,
19_000_000_000,
QWEN3_FP16_VRAM_THRESHOLD
));
}
#[test]
fn qwen3_threshold_for_q8() {
let threshold = qwen3_vram_threshold(4_280_000_000);
assert_eq!(threshold, 6_280_000_000);
assert!(should_use_gpu(true, false, 17_000_000_000, threshold));
}
#[test]
fn qwen3_threshold_for_q3() {
let threshold = qwen3_vram_threshold(2_080_000_000);
assert_eq!(threshold, 4_080_000_000);
assert!(should_use_gpu(true, false, 5_000_000_000, threshold));
}
#[test]
fn qwen2_threshold_for_q6() {
let threshold = qwen2_vram_threshold(6_250_000_000);
assert_eq!(threshold, 8_250_000_000);
assert!(should_use_gpu(true, false, 12_000_000_000, threshold));
}
#[test]
fn qwen3_fp16_does_not_fit_with_bf16_transformer() {
assert!(!should_use_gpu(
true,
false,
400_000_000,
QWEN3_FP16_VRAM_THRESHOLD
));
}
const GB: u64 = 1_000_000_000;
#[test]
fn budget_ok_when_plenty_of_memory() {
let result = preflight_check_budget("UNet", 5 * GB, 20 * GB, Some(10 * GB));
assert!(result.is_ok());
}
#[test]
fn budget_hard_fail_when_exceeds_90pct_available() {
let result = preflight_check_budget("UNet", 19 * GB, 20 * GB, Some(GB));
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(msg.contains("Not enough memory"), "got: {msg}");
}
#[test]
fn budget_ok_at_exactly_90pct_available() {
let result = preflight_check_budget("UNet", 18 * GB, 20 * GB, Some(GB));
assert!(result.is_ok());
}
// One byte past the 90% limit must already be rejected.
#[test]
fn budget_hard_fail_just_over_90pct() {
    let available = 10 * GB;
    let just_over_limit = available * 90 / 100 + 1;
    let outcome = preflight_check_budget("Transformer", just_over_limit, available, Some(0));
    assert!(outcome.is_err());
}
#[test]
fn budget_ok_when_low_free_but_high_available() {
let result = preflight_check_budget("UNet", 5 * GB, 18 * GB, Some(400_000_000));
assert!(result.is_ok());
}
#[test]
fn budget_ok_with_no_free_info() {
let result = preflight_check_budget("UNet", 5 * GB, 20 * GB, None);
assert!(result.is_ok());
}
#[test]
fn budget_hard_fail_with_no_free_info() {
let result = preflight_check_budget("UNet", 19 * GB, 20 * GB, None);
assert!(result.is_err());
}
#[test]
fn budget_ok_small_component() {
let result = preflight_check_budget("CLIP-L", 250_000_000, 16 * GB, Some(8 * GB));
assert!(result.is_ok());
}
#[test]
fn budget_error_message_includes_component_name() {
let result = preflight_check_budget("MyModel", 19 * GB, 20 * GB, Some(GB));
let msg = result.unwrap_err().to_string();
assert!(
msg.contains("MyModel"),
"error should mention component name"
);
}
#[test]
fn budget_error_message_includes_sizes() {
let result = preflight_check_budget("UNet", 19 * GB, 20 * GB, Some(GB));
let msg = result.unwrap_err().to_string();
assert!(msg.contains("19.0 GB"), "should show needed size");
assert!(msg.contains("20.0 GB"), "should show available size");
}
fn flux_1024_activation() -> u64 {
activation_bytes(1024, 1024, 1, 2, ActivationFamily::FluxDit)
}
#[test]
fn offload_when_transformer_exceeds_vram() {
assert!(should_offload(24 * GB, 16 * GB, flux_1024_activation()));
}
#[test]
fn offload_when_transformer_fits_but_no_headroom() {
let xformer = 23_800_000_000;
let free = 24_500_000_000;
assert!(should_offload(xformer, free, flux_1024_activation()));
}
#[test]
fn no_offload_when_plenty_of_vram() {
assert!(!should_offload(12 * GB, 24 * GB, flux_1024_activation()));
}
#[test]
fn no_offload_when_vram_unknown() {
assert!(!should_offload(24 * GB, 0, flux_1024_activation()));
}
#[test]
fn no_offload_when_vram_too_small_for_single_block() {
assert!(!should_offload(24 * GB, 2 * GB, flux_1024_activation()));
}
#[test]
fn activation_bytes_scales_with_area() {
let small = activation_bytes(1024, 1024, 1, 2, ActivationFamily::FluxDit);
let big = activation_bytes(2048, 2048, 1, 2, ActivationFamily::FluxDit);
assert!(
small > 256_000_000,
"1024² FLUX bf16 should clear the floor, got {small}"
);
let ratio = big as f64 / small as f64;
assert!(
(ratio - 4.0).abs() < 0.04,
"expected 4× scaling, got {ratio:.4} (small={small}, big={big})"
);
}
#[test]
fn activation_bytes_scales_with_dtype() {
let bf16 = activation_bytes(2048, 2048, 1, 2, ActivationFamily::FluxDit);
let f32 = activation_bytes(2048, 2048, 1, 4, ActivationFamily::FluxDit);
assert!(bf16 > 256_000_000, "2048² FLUX bf16 should clear floor");
let ratio = f32 as f64 / bf16 as f64;
assert!(
(ratio - 2.0).abs() < 0.02,
"expected f32 = 2× bf16, got {ratio:.4} (bf16={bf16}, f32={f32})"
);
}
#[test]
fn activation_bytes_scales_with_batch() {
let b1 = activation_bytes(2048, 2048, 1, 2, ActivationFamily::SdxlUnet);
let b2 = activation_bytes(2048, 2048, 2, 2, ActivationFamily::SdxlUnet);
assert!(b1 > 256_000_000, "2048² SDXL bf16 b=1 should clear floor");
let ratio = b2 as f64 / b1 as f64;
assert!(
(ratio - 2.0).abs() < 0.02,
"expected b=2 → 2× b=1, got {ratio:.4} (b1={b1}, b2={b2})"
);
}
#[test]
fn activation_bytes_floors_at_256mb() {
let tiny = activation_bytes(64, 64, 1, 2, ActivationFamily::FluxDit);
assert_eq!(
tiny, 256_000_000,
"tiny input must hit the 256 MB floor exactly, got {tiny}"
);
let tiny_te = activation_bytes(64, 64, 1, 2, ActivationFamily::SmallTransformer);
assert_eq!(tiny_te, 256_000_000);
}
#[test]
fn activation_bytes_flux_dit_at_1024_is_in_expected_range() {
let budget = activation_bytes(1024, 1024, 1, 2, ActivationFamily::FluxDit);
assert!(
(200_000_000..=1_000_000_000).contains(&budget),
"FLUX 1024² bf16 cfg=1 budget {budget} bytes outside [200 MB, 1 GB]"
);
}
/// With a 22 GB transformer and 24.5 GB free, the offload decision must flip
/// purely on the activation budget: no offload at 768², offload at 2048².
#[test]
fn should_offload_uses_resolution_scaled_activation() {
    let (xformer, free) = (22_000_000_000, 24_500_000_000);

    let act_768 = activation_bytes(768, 768, 1, 2, ActivationFamily::FluxDit);
    let act_2048 = activation_bytes(2048, 2048, 1, 2, ActivationFamily::FluxDit);
    assert!(act_2048 > act_768, "2048² must exceed 768²");

    let offload_small = should_offload(xformer, free, act_768);
    assert!(
        !offload_small,
        "small budget must NOT trigger offload at this VRAM (act_768={act_768})"
    );

    let offload_large = should_offload(xformer, free, act_2048);
    assert!(
        offload_large,
        "large budget MUST trigger offload at this VRAM (act_2048={act_2048})"
    );
}
/// Known family names map to their own `ActivationFamily`; an unrecognised
/// name is expected to resolve to `FluxDit`.
#[test]
fn activation_family_for_maps_known_and_falls_back() {
    let expectations = [
        ("flux", ActivationFamily::FluxDit),
        ("sdxl", ActivationFamily::SdxlUnet),
        ("qwen-image", ActivationFamily::QwenImageDit),
        ("wuerstchen", ActivationFamily::Wuerstchen),
        // Unknown name: the fallback case.
        ("bogus-family", ActivationFamily::FluxDit),
    ];
    for (name, expected) in expectations {
        assert_eq!(activation_family_for(name), expected);
    }
}
/// Pins the per-element byte width `dtype_bytes` reports for each dtype.
#[test]
fn dtype_bytes_matches_runtime() {
    use candle_core::DType;

    let expectations = [
        (DType::BF16, 2),
        (DType::F16, 2),
        (DType::F32, 4),
        (DType::F64, 8),
        // NOTE(review): a U8 element is one byte wide, yet 2 is pinned here —
        // presumably `dtype_bytes` deliberately falls back to a 2-byte width
        // for dtypes it does not special-case; confirm against the helper.
        (DType::U8, 2),
    ];
    for (dtype, expected) in expectations {
        assert_eq!(dtype_bytes(dtype), expected);
    }
}
/// Test fixture: a `DiscoveredGpu` named `gpu<ordinal>` with 24 GB total
/// VRAM and `free_gb` GB reported free.
fn gpu(ordinal: usize, free_gb: u64) -> DiscoveredGpu {
    let name = format!("gpu{ordinal}");
    let total_vram_bytes = 24 * GB;
    let free_vram_bytes = free_gb * GB;
    DiscoveredGpu {
        ordinal,
        name,
        total_vram_bytes,
        free_vram_bytes,
    }
}
/// Two GPUs with 20 GB free each and a 3 GB request: GPU 0 must be chosen.
#[test]
fn expand_picks_main_gpu_when_it_fits() {
    let gpus = vec![gpu(0, 20), gpu(1, 20)];
    let placement = select_expand_device(&gpus, 3 * GB, false);
    assert_eq!(placement, ExpandPlacement::Gpu(0));
}
/// GPU 0 has only 1 GB free, GPU 1 has 10 GB: a 3 GB request must land on
/// GPU 1.
#[test]
fn expand_falls_through_to_second_gpu_when_main_full() {
    let gpus = vec![gpu(0, 1), gpu(1, 10)];
    let placement = select_expand_device(&gpus, 3 * GB, false);
    assert_eq!(placement, ExpandPlacement::Gpu(1));
}
/// GPUs 0 and 1 are too small (1 GB, 2 GB free); only GPU 2 (10 GB free)
/// can take the 3 GB request.
#[test]
fn expand_walks_all_gpus_in_ordinal_order() {
    let gpus = vec![gpu(0, 1), gpu(1, 2), gpu(2, 10)];
    let placement = select_expand_device(&gpus, 3 * GB, false);
    assert_eq!(placement, ExpandPlacement::Gpu(2));
}
/// Neither GPU (1 GB, 2 GB free) can take a 3 GB request, so the expected
/// placement is the CPU.
#[test]
fn expand_falls_back_to_cpu_when_no_gpu_fits() {
    let gpus = vec![gpu(0, 1), gpu(1, 2)];
    let placement = select_expand_device(&gpus, 3 * GB, false);
    assert_eq!(placement, ExpandPlacement::Cpu);
}
/// An empty GPU list must resolve to the CPU placement.
#[test]
fn expand_falls_back_to_cpu_when_no_gpus_discovered() {
    let gpus = Vec::<DiscoveredGpu>::new();
    let placement = select_expand_device(&gpus, 3 * GB, false);
    assert_eq!(placement, ExpandPlacement::Cpu);
}
/// With the final flag set to `true` (the Metal case, going by the test
/// name), GPU 0 is chosen even though it reports 0 GB free and the request
/// is 100 GB.
#[test]
fn expand_metal_always_picks_gpu_0_when_present() {
    let gpus = vec![gpu(0, 0)];
    let placement = select_expand_device(&gpus, 100 * GB, true);
    assert_eq!(placement, ExpandPlacement::Gpu(0));
}
/// Even with the final flag set to `true`, an empty GPU list must resolve
/// to the CPU placement.
#[test]
fn expand_metal_with_no_gpus_goes_to_cpu() {
    let gpus = Vec::<DiscoveredGpu>::new();
    let placement = select_expand_device(&gpus, 3 * GB, true);
    assert_eq!(placement, ExpandPlacement::Cpu);
}
/// `expand_vram_threshold` must return its input plus
/// `EXPAND_ACTIVATION_HEADROOM`; for a 4 GB input that is pinned at 6 GB.
#[test]
fn expand_threshold_sums_weights_and_headroom() {
    let four_gb_threshold = expand_vram_threshold(4 * GB);
    assert_eq!(four_gb_threshold, 6 * GB);

    let weights = 1_300_000_000;
    assert_eq!(
        expand_vram_threshold(weights),
        weights + EXPAND_ACTIVATION_HEADROOM,
    );
}
/// Free VRAM exactly equal to the requested amount (3 GB vs 3 GB) is not
/// enough: the expected placement is the CPU.
#[test]
fn expand_strictly_greater_than_threshold() {
    let gpus = vec![gpu(0, 3)];
    let placement = select_expand_device(&gpus, 3 * GB, false);
    assert_eq!(placement, ExpandPlacement::Cpu);
}
/// Both GPUs have room (20 GB free each); with `Some(1)` as the preference,
/// GPU 1 must be chosen.
#[test]
fn expand_prefers_requested_gpu_when_it_fits() {
    let gpus = vec![gpu(0, 20), gpu(1, 20)];
    let preferred = Some(1);
    let placement = select_expand_device_with_preference(&gpus, 3 * GB, false, preferred);
    assert_eq!(placement, ExpandPlacement::Gpu(1));
}
/// The preferred GPU 1 has only 1 GB free, so a 3 GB request must fall back
/// to GPU 0 (20 GB free).
#[test]
fn expand_preference_falls_back_when_requested_gpu_cannot_fit() {
    let gpus = vec![gpu(0, 20), gpu(1, 1)];
    let preferred = Some(1);
    let placement = select_expand_device_with_preference(&gpus, 3 * GB, false, preferred);
    assert_eq!(placement, ExpandPlacement::Gpu(0));
}
/// GPU 0 reports 25 GB free against a 24 GB requirement, and ordinal 0 is
/// passed as the second argument: GPU 0 must be chosen.
#[test]
fn select_ltx2_gemma_device_picks_active_gpu_when_room() {
    let gpus = vec![gpu(0, 25)];
    let placement = select_ltx2_gemma_device(&gpus, 0, 24 * GB);
    assert_eq!(placement, LtxGemmaPlacement::Gpu(0));
}
/// The only GPU reports 17 GB free against a 24 GB requirement, so the
/// expected placement is the CPU.
#[test]
fn select_ltx2_gemma_device_falls_to_cpu_when_no_gpu_fits() {
    let gpus = vec![gpu(0, 17)];
    let placement = select_ltx2_gemma_device(&gpus, 0, 24 * GB);
    assert_eq!(placement, LtxGemmaPlacement::Cpu);
}
/// GPU 0 (passed as the second argument) has only 4 GB free; GPU 1 with
/// 25 GB free must be chosen for the 24 GB requirement.
#[test]
fn select_ltx2_gemma_device_picks_sibling_gpu_when_active_full() {
    let gpus = vec![gpu(0, 4), gpu(1, 25)];
    let placement = select_ltx2_gemma_device(&gpus, 0, 24 * GB);
    assert_eq!(placement, LtxGemmaPlacement::Gpu(1));
}
/// GPU 1 (passed as the second argument) has only 4 GB free while GPUs 0 and
/// 2 both have 25 GB: the lower ordinal, GPU 0, must win.
#[test]
fn select_ltx2_gemma_device_walks_remaining_in_ordinal_order() {
    let gpus = vec![gpu(0, 25), gpu(1, 4), gpu(2, 25)];
    let placement = select_ltx2_gemma_device(&gpus, 1, 24 * GB);
    assert_eq!(placement, LtxGemmaPlacement::Gpu(0));
}
/// An empty GPU list must resolve to the CPU placement.
#[test]
fn select_ltx2_gemma_device_returns_cpu_when_no_gpus_discovered() {
    let gpus = Vec::<DiscoveredGpu>::new();
    let placement = select_ltx2_gemma_device(&gpus, 0, 24 * GB);
    assert_eq!(placement, LtxGemmaPlacement::Cpu);
}
/// Pins `LTX2_GEMMA_VRAM_THRESHOLD` at 24_000_000_000 bytes.
#[test]
fn ltx2_gemma_vram_threshold_is_24gb() {
    let expected = 24_000_000_000;
    assert_eq!(LTX2_GEMMA_VRAM_THRESHOLD, expected);
}
/// Walks `resolve_ltx2_gemma_device_override` through every environment
/// combination in a single test: both variables unset, each value of
/// `MOLD_LTX2_GEMMA_DEVICE`, the legacy force-CPU variable alone, and both
/// set together. The values present before the test are put back at the end.
#[test]
fn resolve_ltx2_gemma_device_override_env_behaviors() {
    const MAIN: &str = "MOLD_LTX2_GEMMA_DEVICE";
    const LEGACY: &str = "MOLD_LTX2_DEBUG_FORCE_CPU_PROMPT_ENCODER";

    let prior_main = std::env::var_os(MAIN);
    let prior_legacy = std::env::var_os(LEGACY);

    // SAFETY: environment mutation is process-global. NOTE(review): nothing
    // visible here serializes this test against concurrent env readers.
    let set = |key: &str, value: &str| unsafe { std::env::set_var(key, value) };
    let unset = |key: &str| unsafe { std::env::remove_var(key) };

    // Neither variable set: no override.
    unset(MAIN);
    unset(LEGACY);
    assert_eq!(resolve_ltx2_gemma_device_override(0), None);

    // Main variable only: (value, ordinal passed in, expected override).
    let main_only_cases = [
        ("cpu", 0, Some(LtxGemmaPlacement::Cpu)),
        ("gpu", 1, Some(LtxGemmaPlacement::Gpu(1))),
        ("CPU", 0, Some(LtxGemmaPlacement::Cpu)),
        ("auto", 0, None),
        ("wat", 0, None),
    ];
    for (value, ordinal, expected) in main_only_cases {
        set(MAIN, value);
        assert_eq!(resolve_ltx2_gemma_device_override(ordinal), expected);
    }

    // Legacy variable alone forces the CPU placement.
    unset(MAIN);
    set(LEGACY, "1");
    assert_eq!(
        resolve_ltx2_gemma_device_override(0),
        Some(LtxGemmaPlacement::Cpu),
    );

    // With both set, the main variable's "gpu" is what comes back.
    set(MAIN, "gpu");
    assert_eq!(
        resolve_ltx2_gemma_device_override(2),
        Some(LtxGemmaPlacement::Gpu(2)),
    );

    // Put the environment back exactly as it was found.
    unsafe {
        match prior_main {
            Some(v) => std::env::set_var(MAIN, v),
            None => std::env::remove_var(MAIN),
        }
        match prior_legacy {
            Some(v) => std::env::set_var(LEGACY, v),
            None => std::env::remove_var(LEGACY),
        }
    }
}
/// `keep_te_in_ram()` must be switched on only when `MOLD_KEEP_TE_RAM` is
/// exactly `"1"`; a missing variable or any other value leaves it off.
///
/// The value the variable had before the test is captured up front and put
/// back by a drop guard, so a developer's exported setting is not clobbered
/// and a failing assertion cannot leave a test value behind.
#[test]
fn test_keep_te_in_ram_env_behaviors() {
    const KEY: &str = "MOLD_KEEP_TE_RAM";

    /// Restores `MOLD_KEEP_TE_RAM` to its captured pre-test state on drop.
    struct RestoreOnDrop(Option<std::ffi::OsString>);

    impl Drop for RestoreOnDrop {
        fn drop(&mut self) {
            // SAFETY: same process-global env mutation as the test body.
            // NOTE(review): nothing visible here serializes this test
            // against concurrent env readers.
            unsafe {
                match self.0.take() {
                    Some(prior) => std::env::set_var(KEY, prior),
                    None => std::env::remove_var(KEY),
                }
            }
        }
    }

    // Runs at scope exit, including while unwinding from a failed assertion.
    let _restore = RestoreOnDrop(std::env::var_os(KEY));

    unsafe { std::env::remove_var(KEY) };
    assert!(!keep_te_in_ram(), "missing var must be off");

    unsafe { std::env::set_var(KEY, "1") };
    assert!(keep_te_in_ram(), "\"1\" must enable park");

    // Anything other than the literal "1" must leave the feature off.
    for v in ["", "0", "true", "yes", "TRUE"] {
        unsafe { std::env::set_var(KEY, v) };
        assert!(
            !keep_te_in_ram(),
            "value {v:?} must not enable park (helper is strict ==\"1\")"
        );
    }
}
/// Covers `reserved_vram_bytes()` (platform default, explicit megabyte
/// values, fallback on unparseable input), the saturating subtraction in
/// `usable_free_vram_from_raw`, and that `usable_free_vram_bytes(0)` mirrors
/// the presence of `free_vram_bytes(0)` under an enormous reserve.
///
/// The pre-test value of `MOLD_RESERVE_VRAM_MB` is captured up front and put
/// back by a drop guard, so a developer's exported setting is not clobbered
/// and a failing assertion cannot leave `u64::MAX` behind in the environment.
#[test]
fn test_reserved_vram_and_usable_free_vram() {
    const KEY: &str = "MOLD_RESERVE_VRAM_MB";

    /// Restores `MOLD_RESERVE_VRAM_MB` to its captured pre-test state on drop.
    struct RestoreOnDrop(Option<std::ffi::OsString>);

    impl Drop for RestoreOnDrop {
        fn drop(&mut self) {
            // SAFETY: same process-global env mutation as the test body.
            // NOTE(review): nothing visible here serializes this test
            // against concurrent env readers.
            unsafe {
                match self.0.take() {
                    Some(prior) => std::env::set_var(KEY, prior),
                    None => std::env::remove_var(KEY),
                }
            }
        }
    }

    // Runs at scope exit, including while unwinding from a failed assertion.
    let _restore = RestoreOnDrop(std::env::var_os(KEY));

    // Platform default with the variable absent.
    unsafe { std::env::remove_var(KEY) };
    let default = reserved_vram_bytes();
    #[cfg(target_os = "linux")]
    assert_eq!(default, 400_000_000, "Linux default reserve = 400 MB");
    #[cfg(target_os = "macos")]
    assert_eq!(default, 0, "macOS default = 0 (Metal unified memory)");

    // Explicit values are read as megabytes (1024 → 1_024_000_000 bytes).
    unsafe { std::env::set_var(KEY, "1024") };
    assert_eq!(reserved_vram_bytes(), 1_024_000_000);
    unsafe { std::env::set_var(KEY, "0") };
    assert_eq!(reserved_vram_bytes(), 0);

    // Unparseable input must fall back to the platform default.
    for v in ["", "abc", "-1"] {
        unsafe { std::env::set_var(KEY, v) };
        assert_eq!(
            reserved_vram_bytes(),
            default,
            "unparseable {v:?} must fall back to default"
        );
    }

    // The raw helper subtracts the reserve and clamps at zero.
    unsafe { std::env::remove_var(KEY) };
    assert_eq!(usable_free_vram_from_raw(1_500, 500), 1_000);
    assert_eq!(usable_free_vram_from_raw(500, 1_500), 0);

    // With an enormous reserve, any raw reading must clamp to Some(0).
    unsafe { std::env::set_var(KEY, u64::MAX.to_string()) };
    let has_raw_reading = free_vram_bytes(0).is_some();
    let usable = usable_free_vram_bytes(0);
    assert_eq!(
        usable.is_some(),
        has_raw_reading,
        "usable_free_vram_bytes must mirror free_vram_bytes presence"
    );
    if has_raw_reading {
        assert_eq!(usable, Some(0));
    }
}
/// With 0 as the first argument, `vram_load_delta` must return 0 for any
/// second argument rather than underflowing.
#[test]
fn vram_load_delta_is_saturating_sub() {
    for second in [0, 1_000_000_000, u64::MAX] {
        assert_eq!(vram_load_delta(0, second), 0);
    }
}
/// Creates `dir/name`, sets its length to `size` bytes via `set_len` (no
/// data is written), and returns the full path.
///
/// # Panics
/// Panics if the file cannot be created or resized.
fn write_dummy_file(dir: &std::path::Path, name: &str, size: u64) -> std::path::PathBuf {
    let path = dir.join(name);
    std::fs::File::create(&path)
        .expect("create dummy")
        .set_len(size)
        .expect("set_len");
    path
}
/// When `transformer` and `vae` point at the same 44 GB file (alongside a
/// separate 24 GB text-encoder file), the sequential peak estimate must
/// fall between 45 GB and 50 GB.
#[test]
fn estimate_peak_memory_single_file_does_not_double_count_vae() {
    let dir = tempfile::tempdir().expect("tempdir");
    let root = dir.path();
    let single = write_dummy_file(root, "single.safetensors", 44_000_000_000);
    let te = write_dummy_file(root, "te.safetensors", 24_000_000_000);

    // Same path on both fields: the single-file checkpoint case.
    let transformer = single.clone();
    let vae = single;
    let paths = mold_core::ModelPaths {
        transformer,
        vae,
        text_encoder_files: vec![te],
        transformer_shards: Vec::new(),
        spatial_upscaler: None,
        temporal_upscaler: None,
        distilled_lora: None,
        t5_encoder: None,
        clip_encoder: None,
        t5_tokenizer: None,
        clip_tokenizer: None,
        clip_encoder_2: None,
        clip_tokenizer_2: None,
        text_tokenizer: None,
        decoder: None,
    };

    let peak_gb = estimate_peak_memory(&paths, LoadStrategy::Sequential) as f64 / 1e9;
    assert!(
        peak_gb < 50.0,
        "single-file peak should be ~46 GB, got {peak_gb:.1} GB (double-count bug returned)"
    );
    assert!(
        peak_gb > 45.0,
        "single-file peak should be ~46 GB, got {peak_gb:.1} GB"
    );
}
/// Distinct transformer (4 GB), VAE (1 GB) and T5 encoder (9 GB) files:
/// the sequential peak estimate must land in the 10.5–11.5 GB band.
#[test]
fn estimate_peak_memory_separate_vae_file_still_sums() {
    let dir = tempfile::tempdir().expect("tempdir");
    let root = dir.path();
    let transformer = write_dummy_file(root, "tx.safetensors", 4_000_000_000);
    let vae = write_dummy_file(root, "vae.safetensors", 1_000_000_000);
    let te = write_dummy_file(root, "te.safetensors", 9_000_000_000);

    let paths = mold_core::ModelPaths {
        transformer,
        vae,
        t5_encoder: Some(te),
        transformer_shards: Vec::new(),
        spatial_upscaler: None,
        temporal_upscaler: None,
        distilled_lora: None,
        clip_encoder: None,
        t5_tokenizer: None,
        clip_tokenizer: None,
        clip_encoder_2: None,
        clip_tokenizer_2: None,
        text_encoder_files: Vec::new(),
        text_tokenizer: None,
        decoder: None,
    };

    let peak_gb = estimate_peak_memory(&paths, LoadStrategy::Sequential) as f64 / 1e9;
    let expected_gb = 10.5..11.5;
    assert!(
        expected_gb.contains(&peak_gb),
        "expected ~11 GB, got {peak_gb:.1} GB"
    );
}
/// Two 4 GB transformer shards plus a separate 1 GB VAE file: the
/// sequential peak estimate must land in the 10.5–11.5 GB band.
#[test]
fn estimate_peak_memory_sharded_transformer_with_separate_vae_sums() {
    let dir = tempfile::tempdir().expect("tempdir");
    let root = dir.path();
    let s1 = write_dummy_file(root, "tx-1.safetensors", 4_000_000_000);
    let s2 = write_dummy_file(root, "tx-2.safetensors", 4_000_000_000);
    let vae = write_dummy_file(root, "vae.safetensors", 1_000_000_000);

    // `transformer` repeats the first shard's path.
    let transformer = s1.clone();
    let transformer_shards = vec![s1, s2];
    let paths = mold_core::ModelPaths {
        transformer,
        transformer_shards,
        vae,
        spatial_upscaler: None,
        temporal_upscaler: None,
        distilled_lora: None,
        t5_encoder: None,
        clip_encoder: None,
        t5_tokenizer: None,
        clip_tokenizer: None,
        clip_encoder_2: None,
        clip_tokenizer_2: None,
        text_encoder_files: Vec::new(),
        text_tokenizer: None,
        decoder: None,
    };

    let peak_gb = estimate_peak_memory(&paths, LoadStrategy::Sequential) as f64 / 1e9;
    let expected_gb = 10.5..11.5;
    assert!(
        expected_gb.contains(&peak_gb),
        "sharded peak should be ~11 GB, got {peak_gb:.1} GB"
    );
}
/// One 14 GB file referenced as the transformer, as its sole shard, and as
/// the VAE: the sequential peak estimate must land in the 15.5–16.5 GB band.
#[test]
fn estimate_peak_memory_sharded_single_file_vae_does_not_double_count() {
    let dir = tempfile::tempdir().expect("tempdir");
    let single = write_dummy_file(dir.path(), "single.safetensors", 14_000_000_000);

    // Every role points at the same on-disk file.
    let transformer = single.clone();
    let transformer_shards = vec![single.clone()];
    let vae = single;
    let paths = mold_core::ModelPaths {
        transformer,
        transformer_shards,
        vae,
        spatial_upscaler: None,
        temporal_upscaler: None,
        distilled_lora: None,
        t5_encoder: None,
        clip_encoder: None,
        t5_tokenizer: None,
        clip_tokenizer: None,
        clip_encoder_2: None,
        clip_tokenizer_2: None,
        text_encoder_files: Vec::new(),
        text_tokenizer: None,
        decoder: None,
    };

    let peak_gb = estimate_peak_memory(&paths, LoadStrategy::Sequential) as f64 / 1e9;
    let expected_gb = 15.5..16.5;
    assert!(
        expected_gb.contains(&peak_gb),
        "sharded single-file peak should be ~16 GB, got {peak_gb:.1} GB"
    );
}
/// One 14 GB file referenced as transformer, VAE and both CLIP encoders:
/// the sequential peak estimate must land in the 15.5–16.5 GB band.
#[test]
fn estimate_peak_memory_single_file_sdxl_does_not_count_clip_views_as_full_checkpoints() {
    let dir = tempfile::tempdir().expect("tempdir");
    let single = write_dummy_file(dir.path(), "single.safetensors", 14_000_000_000);

    // Every role points at the same on-disk file.
    let transformer = single.clone();
    let vae = single.clone();
    let clip_encoder = Some(single.clone());
    let clip_encoder_2 = Some(single);
    let paths = mold_core::ModelPaths {
        transformer,
        vae,
        clip_encoder,
        clip_encoder_2,
        transformer_shards: Vec::new(),
        spatial_upscaler: None,
        temporal_upscaler: None,
        distilled_lora: None,
        t5_encoder: None,
        t5_tokenizer: None,
        clip_tokenizer: None,
        clip_tokenizer_2: None,
        text_encoder_files: Vec::new(),
        text_tokenizer: None,
        decoder: None,
    };

    let peak_gb = estimate_peak_memory(&paths, LoadStrategy::Sequential) as f64 / 1e9;
    let expected_gb = 15.5..16.5;
    assert!(
        expected_gb.contains(&peak_gb),
        "single-file SDXL peak should be ~16 GB, got {peak_gb:.1} GB"
    );
}
/// Acquires a process-wide mutex and returns its guard; the mutex is created
/// on first use.
///
/// A poisoned mutex (an earlier holder panicked) is recovered instead of
/// propagating the poison, so one failed test does not cascade.
fn vae_env_lock() -> std::sync::MutexGuard<'static, ()> {
    static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();

    let mutex = LOCK.get_or_init(std::sync::Mutex::default);
    match mutex.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// With `MOLD_VAE_DTYPE` unset, `resolve_vae_dtype` must hand back whatever
/// default it was given.
#[test]
fn resolve_vae_dtype_unset_returns_default() {
    use candle_core::DType;

    let _g = vae_env_lock();
    // SAFETY: env mutation happens while holding `vae_env_lock`.
    unsafe { std::env::remove_var("MOLD_VAE_DTYPE") };

    for default in [DType::BF16, DType::F16] {
        assert_eq!(resolve_vae_dtype(default), default);
    }
}
/// `MOLD_VAE_DTYPE=auto` must behave like no override: the supplied default
/// (BF16) comes back unchanged.
#[test]
fn resolve_vae_dtype_auto_returns_default() {
    use candle_core::DType;

    let _g = vae_env_lock();
    // The variable is cleared again before asserting.
    let resolved = {
        // SAFETY: env mutation happens while holding `vae_env_lock`.
        unsafe { std::env::set_var("MOLD_VAE_DTYPE", "auto") };
        let out = resolve_vae_dtype(DType::BF16);
        unsafe { std::env::remove_var("MOLD_VAE_DTYPE") };
        out
    };
    assert_eq!(resolved, DType::BF16);
}
/// `MOLD_VAE_DTYPE=fp32` must yield F32 even when the supplied default is
/// BF16.
#[test]
fn resolve_vae_dtype_fp32_forces_f32_regardless_of_default() {
    use candle_core::DType;

    let _g = vae_env_lock();
    // The variable is cleared again before asserting.
    let resolved = {
        // SAFETY: env mutation happens while holding `vae_env_lock`.
        unsafe { std::env::set_var("MOLD_VAE_DTYPE", "fp32") };
        let out = resolve_vae_dtype(DType::BF16);
        unsafe { std::env::remove_var("MOLD_VAE_DTYPE") };
        out
    };
    assert_eq!(resolved, DType::F32);
}
/// `MOLD_VAE_DTYPE=bf16` must yield BF16 even when the supplied default is
/// F32.
#[test]
fn resolve_vae_dtype_bf16_forces_bf16_even_when_default_is_f32() {
    use candle_core::DType;

    let _g = vae_env_lock();
    // The variable is cleared again before asserting.
    let resolved = {
        // SAFETY: env mutation happens while holding `vae_env_lock`.
        unsafe { std::env::set_var("MOLD_VAE_DTYPE", "bf16") };
        let out = resolve_vae_dtype(DType::F32);
        unsafe { std::env::remove_var("MOLD_VAE_DTYPE") };
        out
    };
    assert_eq!(resolved, DType::BF16);
}
/// Each of `fp16`, `f16`, `FP16` and `F16` in `MOLD_VAE_DTYPE` must resolve
/// to F16, overriding a BF16 default.
#[test]
fn resolve_vae_dtype_fp16_alias_recognised() {
    use candle_core::DType;

    let _g = vae_env_lock();
    let aliases = ["fp16", "f16", "FP16", "F16"];
    for value in aliases {
        // SAFETY: env mutation happens while holding `vae_env_lock`.
        unsafe { std::env::set_var("MOLD_VAE_DTYPE", value) };
        let resolved = resolve_vae_dtype(DType::BF16);
        assert_eq!(
            resolved,
            DType::F16,
            "value `{value}` should resolve to F16"
        );
    }
    unsafe { std::env::remove_var("MOLD_VAE_DTYPE") };
}
/// An unrecognised `MOLD_VAE_DTYPE` value (`fp64`) must fall back to the
/// supplied default rather than erroring.
#[test]
fn resolve_vae_dtype_invalid_value_falls_back_to_default() {
    use candle_core::DType;

    let _g = vae_env_lock();
    // The variable is cleared again before asserting.
    let resolved = {
        // SAFETY: env mutation happens while holding `vae_env_lock`.
        unsafe { std::env::set_var("MOLD_VAE_DTYPE", "fp64") };
        let out = resolve_vae_dtype(DType::BF16);
        unsafe { std::env::remove_var("MOLD_VAE_DTYPE") };
        out
    };
    assert_eq!(
        resolved,
        DType::BF16,
        "invalid value must fall back, not error"
    );
}
/// Pins `MEMORY_BUDGET_HEADROOM` at 2_000_000_000 bytes; the failure message
/// names the server-side error text that has to be updated alongside it.
#[test]
fn memory_budget_headroom_is_2gb() {
    let expected = 2_000_000_000;
    assert_eq!(
        MEMORY_BUDGET_HEADROOM, expected,
        "MEMORY_BUDGET_HEADROOM changed — update the rejection error message \
         in mold-server::model_manager::check_model_memory_budget to match"
    );
}
}