use anyhow::Result;
use candle_core::Tensor;
#[cfg(feature = "cuda")]
use std::ffi::c_void;
use std::sync::{Arc, Mutex};
/// One decimal gigabyte (10^9 bytes). This is the unit of `MOLD_PINNED_VRAM_MAX_GB` and of
/// the tracker's log output.
const GB: u64 = 1_000_000_000;

/// Soft cap, in bytes, on how much host memory may be page-locked for offloading.
///
/// Resolution order:
/// 1. `MOLD_PINNED_VRAM_MAX_GB`, if it parses as a number strictly greater than zero.
///    The value is in decimal GB, fractions are allowed, and the result is truncated to
///    whole bytes. Unparseable, zero, negative or NaN values are ignored.
/// 2. Otherwise a fraction of total system RAM: 40% on macOS, 50% elsewhere. When RAM
///    size cannot be determined, 32 GB is assumed.
pub fn pinned_cap_bytes() -> u64 {
    let override_gb = std::env::var("MOLD_PINNED_VRAM_MAX_GB")
        .ok()
        .and_then(|raw| raw.trim().parse::<f64>().ok())
        .filter(|&gb| gb > 0.0);
    if let Some(gb) = override_gb {
        return (gb * GB as f64) as u64;
    }

    let total_ram = total_system_ram_bytes().unwrap_or(32 * GB);
    let fraction = if cfg!(target_os = "macos") { 0.4 } else { 0.5 };
    (total_ram as f64 * fraction) as u64
}
/// Whether block prefetching is enabled, controlled by `MOLD_OFFLOAD_PREFETCH`.
///
/// Prefetch is on by default. It is disabled only when the variable is set to `off`, `0`
/// or `false`, compared case-insensitively after trimming whitespace. Any other value,
/// including an empty or non-UTF-8 one, leaves it enabled.
pub fn prefetch_enabled_from_env() -> bool {
    const DISABLING_VALUES: [&str; 3] = ["off", "0", "false"];
    std::env::var("MOLD_OFFLOAD_PREFETCH")
        .map(|raw| {
            let normalized = raw.trim().to_ascii_lowercase();
            !DISABLING_VALUES.contains(&normalized.as_str())
        })
        .unwrap_or(true)
}
/// Total physical RAM in bytes, read from the `MemTotal:` line of `/proc/meminfo`.
///
/// Returns `None` if the file cannot be read, has no `MemTotal:` line, or the first token
/// after the label does not parse as an integer. The kernel reports this figure in KiB,
/// hence the multiplication by 1024 (saturating, so it cannot overflow).
#[cfg(target_os = "linux")]
pub fn total_system_ram_bytes() -> Option<u64> {
    let meminfo = std::fs::read_to_string("/proc/meminfo").ok()?;
    let mem_total = meminfo
        .lines()
        .find_map(|line| line.strip_prefix("MemTotal:"))?;
    let kib: u64 = mem_total.split_ascii_whitespace().next()?.parse().ok()?;
    Some(kib.saturating_mul(1024))
}

/// Non-Linux stub: RAM size is not probed here, so callers fall back to their own default.
#[cfg(not(target_os = "linux"))]
pub fn total_system_ram_bytes() -> Option<u64> {
    None
}
/// Size of the largest entry in `sizes`, or `0` when `sizes` is empty.
///
/// Used to size a single reusable buffer that can hold any one block.
pub fn largest_block_size_bytes(sizes: &[usize]) -> usize {
    sizes.iter().fold(0, |largest, &size| largest.max(size))
}
/// Soft budget for page-locked (pinned) host memory.
///
/// Clones share the same counters (both fields are `Arc`-backed). One tracker can therefore
/// be handed to every pinned region, and each region's release returns its bytes to the
/// common budget. Exceeding the cap is not an error: [`try_reserve`](Self::try_reserve)
/// simply refuses, and only the first refusal is logged (at `info` level).
#[derive(Debug, Clone)]
pub struct PinnedMemoryTracker {
    /// Maximum number of bytes that may be reserved at the same time.
    cap_bytes: u64,
    /// Bytes currently reserved: successful reservations minus releases.
    used: Arc<Mutex<u64>>,
    /// Whether the one-shot "soft cap reached" log line has already been emitted.
    capped_warning_issued: Arc<Mutex<bool>>,
}

impl PinnedMemoryTracker {
    /// Creates a tracker with `cap_bytes` of budget and nothing reserved.
    pub fn new(cap_bytes: u64) -> Self {
        Self {
            cap_bytes,
            used: Arc::new(Mutex::new(0)),
            capped_warning_issued: Arc::new(Mutex::new(false)),
        }
    }

    /// The cap this tracker was constructed with.
    // Presumably only read by tests/diagnostics in some builds; the allow silences dead-code
    // warnings there.
    #[allow(dead_code)]
    pub fn cap_bytes(&self) -> u64 {
        self.cap_bytes
    }

    /// Bytes currently reserved across all clones of this tracker.
    pub fn used_bytes(&self) -> u64 {
        *Self::lock(&self.used)
    }

    /// Attempts to reserve `n` bytes of budget.
    ///
    /// Returns `true` and charges the budget when `used + n <= cap`. Otherwise it returns
    /// `false` and leaves the budget untouched. The sum is computed with `checked_add`, so a
    /// request that would overflow `u64` is rejected instead of wrapping the counter. This
    /// matters when the cap is `u64::MAX`. The first rejection emits a single `info` log.
    pub fn try_reserve(&self, n: u64) -> bool {
        let used_at_rejection = {
            let mut used = Self::lock(&self.used);
            match used.checked_add(n) {
                Some(total) if total <= self.cap_bytes => {
                    *used = total;
                    return true;
                }
                _ => *used,
            }
        };

        // Check-and-set the one-shot flag, then log with no lock held.
        let first_rejection =
            !std::mem::replace(&mut *Self::lock(&self.capped_warning_issued), true);
        if first_rejection {
            tracing::info!(
                "FLUX offload: pinned-memory soft cap reached ({:.2} GB used, cap {:.2} GB) — \
                 remaining blocks will fall back to pageable copies. \
                 Override with MOLD_PINNED_VRAM_MAX_GB.",
                used_at_rejection as f64 / GB as f64,
                self.cap_bytes as f64 / GB as f64,
            );
        }
        false
    }

    /// Returns `n` bytes to the budget. The counter saturates at zero, so an over-release
    /// cannot underflow it.
    pub fn release(&self, n: u64) {
        let mut used = Self::lock(&self.used);
        *used = used.saturating_sub(n);
    }

    /// Locks `m`, recovering the guard if a previous holder panicked.
    ///
    /// Every guarded value here is a plain counter or flag written in a single assignment,
    /// so a poisoned lock cannot contain a half-updated value. Recovering matters because
    /// `release` runs from `PinnedRegion`'s `Drop`. A panic there while already unwinding
    /// would abort the process.
    fn lock<T>(m: &Mutex<T>) -> std::sync::MutexGuard<'_, T> {
        m.lock().unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
/// RAII handle for a host buffer registered as page-locked memory with the CUDA driver.
///
/// Built by `try_pin_to_host` after a successful `cuMemHostRegister_v2`. Dropping it
/// unregisters the buffer (CUDA builds only) and returns `n_bytes` to `tracker`.
///
/// NOTE(review): the region stores only a raw pointer. It does not own or borrow the tensor
/// storage it was created from. Presumably callers keep that tensor alive for at least as
/// long as the region; confirm at call sites.
pub struct PinnedRegion {
    /// Base of the registered host range, exactly as passed to `cuMemHostRegister_v2`.
    #[cfg(feature = "cuda")]
    ptr: *mut c_void,
    /// Size of the registered range; also the amount reserved on `tracker`.
    n_bytes: u64,
    /// Budget this region's bytes were charged to; released on drop.
    tracker: PinnedMemoryTracker,
}
// SAFETY: `ptr` is never dereferenced in Rust. It is only handed back to
// `cuMemHostUnregister` in `Drop`. The other fields are already `Send`.
// NOTE(review): this assumes the CUDA driver allows unregistering from a thread other than
// the registering one. Confirm against the CUDA driver API docs.
#[cfg(feature = "cuda")]
unsafe impl Send for PinnedRegion {}
// SAFETY: no visible method reads `ptr` through `&self`; shared references expose nothing
// that touches the registered memory.
#[cfg(feature = "cuda")]
unsafe impl Sync for PinnedRegion {}
impl Drop for PinnedRegion {
    /// Unregisters the host range (CUDA builds) and hands its bytes back to the tracker.
    ///
    /// Unregistration failures are logged at debug level and otherwise ignored. A
    /// destructor cannot report them, and the budget must be returned regardless.
    fn drop(&mut self) {
        #[cfg(feature = "cuda")]
        {
            use candle_core::cuda_backend::cudarc::driver::sys;
            // SAFETY: `self.ptr` is the base pointer that was registered when this region was
            // built by `try_pin_to_host`. It is unregistered only here, exactly once.
            let status = unsafe { sys::cuMemHostUnregister(self.ptr) };
            if status != sys::CUresult::CUDA_SUCCESS {
                tracing::debug!(
                    "cuMemHostUnregister returned {:?} for {} bytes (continuing)",
                    status,
                    self.n_bytes
                );
            }
        }
        self.tracker.release(self.n_bytes);
    }
}
/// Best-effort page-locking of a CPU tensor's backing buffer, for faster host→device copies.
///
/// Returns `Ok(None)` in each of these cases; the tensor is untouched and the tracker ends
/// unchanged:
/// - the tensor is not on the CPU;
/// - `cpu_tensor_byte_view` yields no span (non-contiguous, unsupported dtype, ...);
/// - the span is empty;
/// - the crate was built without the `cuda` feature, so nothing can be pinned;
/// - `tracker` has no budget left for the span;
/// - the driver rejects the registration (logged at debug level).
///
/// On success the returned [`PinnedRegion`] owns the reservation. It unregisters the memory
/// and releases the budget when dropped.
///
/// NOTE(review): the region does not keep `tensor` (or its storage) alive. Presumably callers
/// hold the tensor for at least as long as the region; confirm at call sites.
///
/// # Errors
/// Propagates any error from the internal byte-view helper.
pub fn try_pin_to_host(
    tensor: &Tensor,
    tracker: &PinnedMemoryTracker,
) -> Result<Option<PinnedRegion>> {
    if !tensor.device().is_cpu() {
        return Ok(None);
    }
    let Some((ptr, n_bytes)) = cpu_tensor_byte_view(tensor)? else {
        return Ok(None);
    };
    if n_bytes == 0 {
        return Ok(None);
    }

    // Without CUDA there is no driver to page-lock memory with. Bail out *before* touching
    // the tracker. Otherwise a non-CUDA build would churn the budget mutex and could trip the
    // one-shot "soft cap reached" log for a pin that can never happen.
    #[cfg(not(feature = "cuda"))]
    {
        let _ = (ptr, tracker);
        Ok(None)
    }

    #[cfg(feature = "cuda")]
    {
        use candle_core::cuda_backend::cudarc::driver::sys;
        // CUDA driver flag: the registered memory is treated as pinned by all CUDA contexts,
        // not just the current one.
        const CU_MEMHOSTREGISTER_PORTABLE: u32 = 1;

        let reserved = n_bytes as u64;
        if !tracker.try_reserve(reserved) {
            return Ok(None);
        }
        // SAFETY: `ptr`/`n_bytes` describe a live allocation owned by `tensor`'s storage,
        // which `tensor` keeps alive for the duration of this call. Registration neither
        // moves nor frees the memory.
        let res = unsafe {
            sys::cuMemHostRegister_v2(ptr as *mut c_void, n_bytes, CU_MEMHOSTREGISTER_PORTABLE)
        };
        if res != sys::CUresult::CUDA_SUCCESS {
            // Hand the budget back: nothing was pinned.
            tracker.release(reserved);
            tracing::debug!(
                "cuMemHostRegister_v2 returned {:?} for {} bytes — falling back to pageable",
                res,
                n_bytes
            );
            return Ok(None);
        }
        Ok(Some(PinnedRegion {
            ptr: ptr as *mut c_void,
            n_bytes: reserved,
            tracker: tracker.clone(),
        }))
    }
}
/// Locates the raw byte buffer behind a contiguous, CPU-backed tensor.
///
/// On success it returns `(ptr, len)`. `ptr` is the start of the tensor's backing vector and
/// `len` is that whole vector's size in bytes. This is the full storage, not just the
/// `[start_offset, start_offset + elem_count)` window this tensor addresses.
/// NOTE(review): presumably intentional (the whole storage is what gets pinned); revisit if
/// tensors sharing one storage are ever pinned separately.
///
/// Returns `Ok(None)` in these cases:
/// - the tensor is non-contiguous or not CPU-backed;
/// - its storage variant/dtype pair is not one of those matched below;
/// - its element window would extend past the buffer.
///
/// No current path returns `Err`.
fn cpu_tensor_byte_view(tensor: &Tensor) -> Result<Option<(*const u8, usize)>> {
    use candle_core::{CpuStorage, DType, Storage};

    /// Reinterprets a typed slice as (start pointer, length in bytes).
    fn raw_parts<T>(items: &[T]) -> (*const u8, usize) {
        (items.as_ptr().cast::<u8>(), std::mem::size_of_val(items))
    }

    if !tensor.is_contiguous() {
        return Ok(None);
    }
    let (storage, layout) = tensor.storage_and_layout();
    let Storage::Cpu(cpu) = &*storage else {
        return Ok(None);
    };

    let dtype = tensor.dtype();
    let elem_size = dtype.size_in_bytes();
    let window_start = layout.start_offset() * elem_size;
    let window_len = tensor.elem_count() * elem_size;

    let (buf_ptr, buf_len) = match (cpu, dtype) {
        (CpuStorage::U8(v), DType::U8) => raw_parts(v.as_slice()),
        (CpuStorage::U32(v), DType::U32) => raw_parts(v.as_slice()),
        (CpuStorage::I16(v), DType::I16) => raw_parts(v.as_slice()),
        (CpuStorage::I32(v), DType::I32) => raw_parts(v.as_slice()),
        (CpuStorage::I64(v), DType::I64) => raw_parts(v.as_slice()),
        (CpuStorage::BF16(v), DType::BF16) => raw_parts(v.as_slice()),
        (CpuStorage::F16(v), DType::F16) => raw_parts(v.as_slice()),
        (CpuStorage::F32(v), DType::F32) => raw_parts(v.as_slice()),
        (CpuStorage::F64(v), DType::F64) => raw_parts(v.as_slice()),
        (CpuStorage::F8E4M3(v), DType::F8E4M3) => raw_parts(v.as_slice()),
        _ => return Ok(None),
    };

    // Defensive: never hand out a buffer that does not cover the tensor's element window.
    if window_start + window_len > buf_len {
        return Ok(None);
    }
    Ok(Some((buf_ptr, buf_len)))
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device, Tensor};
    #[test]
    fn try_pin_to_host_no_op_on_cpu_tensor() {
        let t = Tensor::zeros((4, 4), DType::F32, &Device::Cpu).unwrap();
        let tracker = PinnedMemoryTracker::new(10 * GB);
        let r = try_pin_to_host(&t, &tracker).expect("pinning a CPU tensor must not error");
        // Budget accounting must match the outcome. A live region holds a reservation; a
        // skipped or failed pin holds none.
        match &r {
            Some(_) => assert!(
                tracker.used_bytes() > 0,
                "a live pinned region must hold a reservation"
            ),
            None => assert_eq!(
                tracker.used_bytes(),
                0,
                "a skipped pin must not hold a reservation"
            ),
        }
        // Without CUDA there is nothing to pin with.
        #[cfg(not(feature = "cuda"))]
        assert!(r.is_none(), "non-CUDA builds must never produce a pinned region");
        drop(r);
        assert_eq!(
            tracker.used_bytes(),
            0,
            "dropping the region must release its reservation"
        );
    }
    #[test]
    fn pinned_memory_tracker_caps_total_bytes() {
        let t = PinnedMemoryTracker::new(100);
        assert!(
            t.try_reserve(40),
            "first reservation under cap must succeed"
        );
        assert!(
            t.try_reserve(50),
            "second reservation that fits must succeed"
        );
        assert_eq!(t.used_bytes(), 90);
        assert!(
            !t.try_reserve(20),
            "reservation that would exceed the cap must be rejected"
        );
        assert_eq!(
            t.used_bytes(),
            90,
            "rejected reservation must not consume budget"
        );
        assert!(t.try_reserve(10), "exactly-fits reservation must succeed");
        assert_eq!(t.used_bytes(), 100);
        t.release(40);
        assert_eq!(t.used_bytes(), 60);
        assert!(
            t.try_reserve(40),
            "release should let new reservations through"
        );
    }
    #[test]
    fn prefetch_buffer_sized_for_largest_block() {
        assert_eq!(largest_block_size_bytes(&[]), 0);
        assert_eq!(largest_block_size_bytes(&[100]), 100);
        assert_eq!(largest_block_size_bytes(&[100, 200, 50]), 200);
        assert_eq!(largest_block_size_bytes(&[7, 7, 7]), 7);
        assert_eq!(
            largest_block_size_bytes(&[1, 1_000_000_000, 999]),
            1_000_000_000
        );
    }
    #[test]
    fn prefetch_disabled_via_env() {
        unsafe { std::env::remove_var("MOLD_OFFLOAD_PREFETCH") };
        assert!(
            prefetch_enabled_from_env(),
            "missing var must default to enabled"
        );
        for off in ["off", "OFF", "0", "false", "False"] {
            unsafe { std::env::set_var("MOLD_OFFLOAD_PREFETCH", off) };
            assert!(
                !prefetch_enabled_from_env(),
                "value {off:?} must disable prefetch"
            );
        }
        for on in ["on", "1", "true", "yes", "anything-else"] {
            unsafe { std::env::set_var("MOLD_OFFLOAD_PREFETCH", on) };
            assert!(
                prefetch_enabled_from_env(),
                "value {on:?} must keep prefetch enabled"
            );
        }
        unsafe { std::env::remove_var("MOLD_OFFLOAD_PREFETCH") };
    }
    #[test]
    fn try_pin_returns_none_when_tracker_cap_exceeded() {
        let t = Tensor::ones((16, 16), DType::F32, &Device::Cpu).unwrap();
        let tracker = PinnedMemoryTracker::new(0);
        let r = try_pin_to_host(&t, &tracker).expect("zero-cap pin must not error");
        assert!(r.is_none(), "zero-cap tracker must yield no pinned region");
        assert!(!tracker.try_reserve(1));
        assert!(!tracker.try_reserve(1));
    }
    #[test]
    fn try_pin_handles_every_supported_cpu_dtype() {
        let device = Device::Cpu;
        let tracker = PinnedMemoryTracker::new(10 * GB);
        for dtype in [
            DType::U8,
            DType::U32,
            DType::I64,
            DType::F32,
            DType::F64,
            DType::BF16,
            DType::F16,
        ] {
            let t = Tensor::zeros((8, 8), dtype, &device).unwrap();
            let region = try_pin_to_host(&t, &tracker)
                .unwrap_or_else(|e| panic!("dtype {dtype:?} broke try_pin_to_host: {e}"));
            // Whatever the outcome, dropping the result must leave the budget balanced.
            drop(region);
            assert_eq!(
                tracker.used_bytes(),
                0,
                "dtype {dtype:?} leaked pinned-memory budget"
            );
        }
    }
    #[test]
    fn try_pin_skips_non_contiguous_views() {
        let base = Tensor::ones((8, 16), DType::F32, &Device::Cpu).unwrap();
        let view = base.transpose(0, 1).unwrap();
        assert!(
            !view.is_contiguous(),
            "transposed view must be non-contiguous"
        );
        let tracker = PinnedMemoryTracker::new(10 * GB);
        let r = try_pin_to_host(&view, &tracker).expect("non-contiguous must not error");
        assert!(r.is_none(), "non-contiguous tensors must skip pinning");
        assert_eq!(
            tracker.used_bytes(),
            0,
            "no reservation may charge against the cap when pin is skipped"
        );
    }
    #[test]
    fn try_pin_skips_when_byte_count_is_zero() {
        let t = Tensor::zeros((0, 8), DType::F32, &Device::Cpu).unwrap();
        let tracker = PinnedMemoryTracker::new(10 * GB);
        let r = try_pin_to_host(&t, &tracker).expect("empty tensor must not error");
        assert!(r.is_none(), "empty tensors must skip pinning");
        assert_eq!(tracker.used_bytes(), 0);
    }
    #[test]
    fn pinned_memory_tracker_cap_bytes_accessor_returns_construction_value() {
        let t = PinnedMemoryTracker::new(7 * GB);
        assert_eq!(t.cap_bytes(), 7 * GB);
    }
    #[test]
    fn pinned_cap_respects_env_override() {
        unsafe { std::env::remove_var("MOLD_PINNED_VRAM_MAX_GB") };
        let baseline = pinned_cap_bytes();
        assert!(baseline > 0, "default cap must be positive");
        unsafe { std::env::set_var("MOLD_PINNED_VRAM_MAX_GB", "8") };
        assert_eq!(pinned_cap_bytes(), 8 * GB);
        unsafe { std::env::set_var("MOLD_PINNED_VRAM_MAX_GB", "0.5") };
        assert_eq!(pinned_cap_bytes(), GB / 2);
        unsafe { std::env::set_var("MOLD_PINNED_VRAM_MAX_GB", "garbage") };
        assert_eq!(pinned_cap_bytes(), baseline);
        unsafe { std::env::remove_var("MOLD_PINNED_VRAM_MAX_GB") };
    }
}