use std::sync::atomic::{AtomicI64, Ordering};
// Byte counters updated by `record_alloc` / `record_free`. They are signed and
// nothing clamps them at zero, so frees without matching allocs drive them negative.

/// Sum over every label: `record_alloc` adds to it and `record_free` subtracts from it.
static TOTAL_BYTES: AtomicI64 = AtomicI64::new(0);
/// Bytes recorded under labels starting with `weight:`.
static WEIGHT_BYTES: AtomicI64 = AtomicI64::new(0);
/// Bytes recorded under labels starting with `scratch`, `ckpt` or `layer.`.
static SCRATCH_BYTES: AtomicI64 = AtomicI64::new(0);
/// Bytes recorded under labels starting with `kv`.
static KV_BYTES: AtomicI64 = AtomicI64::new(0);
/// Bytes recorded under labels starting with `lora` or `adam`.
static LORA_BYTES: AtomicI64 = AtomicI64::new(0);
/// Bytes recorded under labels that match none of the prefixes listed on the other counters.
static OTHER_BYTES: AtomicI64 = AtomicI64::new(0);
/// Picks the per-category counter that `label` should be accounted to.
///
/// Matching is by case-sensitive prefix, checked in this order:
/// `weight:` -> weights; `scratch` / `ckpt` / `layer.` -> scratch;
/// `kv` -> KV; `lora` / `adam` -> LoRA; anything else -> other.
fn category(label: &str) -> &'static AtomicI64 {
    // True when `label` begins with at least one of the given prefixes.
    let begins_with_any = |prefixes: &[&str]| prefixes.iter().any(|p| label.starts_with(*p));

    if begins_with_any(&["weight:"]) {
        return &WEIGHT_BYTES;
    }
    if begins_with_any(&["scratch", "ckpt", "layer."]) {
        return &SCRATCH_BYTES;
    }
    if begins_with_any(&["kv"]) {
        return &KV_BYTES;
    }
    if begins_with_any(&["lora", "adam"]) {
        return &LORA_BYTES;
    }
    &OTHER_BYTES
}
/// Smallest size in bytes (1 << 20 = 1 MiB) for which `record_alloc` and
/// `record_free` emit a log line. Only compiled on non-wasm32 targets; the
/// logging branches that read it are gated by the same cfg.
#[cfg(not(target_arch = "wasm32"))]
const LOG_THRESHOLD: u64 = 1 << 20;
pub fn snapshot_mib() -> (i64, i64, i64, i64, i64, i64) {
(
TOTAL_BYTES.load(Ordering::Relaxed) >> 20,
WEIGHT_BYTES.load(Ordering::Relaxed) >> 20,
SCRATCH_BYTES.load(Ordering::Relaxed) >> 20,
KV_BYTES.load(Ordering::Relaxed) >> 20,
LORA_BYTES.load(Ordering::Relaxed) >> 20,
OTHER_BYTES.load(Ordering::Relaxed) >> 20,
)
}
/// Formats the values returned by `snapshot_mib` as a one-line summary of the
/// form `tot=.. w=.. s=.. kv=.. lora=.. o=..` (all values in MiB).
pub fn breakdown_str() -> String {
    let (total, weights, scratch, kv, lora, other) = snapshot_mib();
    format!("tot={total} w={weights} s={scratch} kv={kv} lora={lora} o={other}")
}
/// Records an allocation of `bytes` bytes under `label`.
///
/// Adds the size to the grand total and to the counter selected by
/// `category(label)`. On non-wasm32 targets, allocations of at least
/// `LOG_THRESHOLD` bytes are additionally passed to `log_line` together with
/// the new total.
///
/// Sizes larger than `i64::MAX` are clamped to `i64::MAX` before being added
/// to the signed counters.
pub fn record_alloc(label: &str, bytes: u64) {
    // The counters are signed. A bare `as i64` would reinterpret sizes >= 2^63
    // as negative and book the allocation as a free, so clamp first.
    let delta = bytes.min(i64::MAX as u64) as i64;
    // `fetch_add` wraps on overflow; derive the post-add value the same way so
    // this never panics in overflow-checked builds.
    let total = TOTAL_BYTES
        .fetch_add(delta, Ordering::Relaxed)
        .wrapping_add(delta);
    category(label).fetch_add(delta, Ordering::Relaxed);
    #[cfg(not(target_arch = "wasm32"))]
    if bytes >= LOG_THRESHOLD {
        log_line(&format!(
            "ALLOC {label} +{}MiB total={}MiB",
            bytes >> 20,
            total >> 20
        ));
    }
    // `total` is only consumed by the logging branch; silence the unused
    // warning where that branch is compiled out.
    #[cfg(target_arch = "wasm32")]
    let _ = total;
}
/// Records that `bytes` bytes previously recorded under `label` were freed.
///
/// Subtracts the size from the grand total and from the counter selected by
/// `category(label)`. On non-wasm32 targets, frees of at least
/// `LOG_THRESHOLD` bytes are additionally passed to `log_line` together with
/// the new total.
///
/// Sizes larger than `i64::MAX` are clamped to `i64::MAX` before being
/// subtracted from the signed counters. Nothing stops a counter from going
/// negative when more is freed than was recorded.
pub fn record_free(label: &str, bytes: u64) {
    // The counters are signed. A bare `as i64` would reinterpret sizes >= 2^63
    // as negative and book the free as an allocation, so clamp first.
    let delta = bytes.min(i64::MAX as u64) as i64;
    // `fetch_sub` wraps on overflow; derive the post-sub value the same way so
    // this never panics in overflow-checked builds.
    let total = TOTAL_BYTES
        .fetch_sub(delta, Ordering::Relaxed)
        .wrapping_sub(delta);
    category(label).fetch_sub(delta, Ordering::Relaxed);
    #[cfg(not(target_arch = "wasm32"))]
    if bytes >= LOG_THRESHOLD {
        log_line(&format!(
            "FREE {label} -{}MiB total={}MiB",
            bytes >> 20,
            total >> 20
        ));
    }
    // `total` is only consumed by the logging branch; silence the unused
    // warning where that branch is compiled out.
    #[cfg(target_arch = "wasm32")]
    let _ = total;
}
/// Returns the grand-total counter in MiB, rounded down to a whole MiB
/// (a negative counter rounds toward negative infinity).
pub fn total_mib() -> i64 {
    let total_bytes = TOTAL_BYTES.load(Ordering::Relaxed);
    // Arithmetic shift by 20 bits == floor division by 1 MiB.
    total_bytes >> 20
}
/// Emits a `MARK <label> total=<N>MiB` line through `log_line`, where `<N>` is
/// the current value of `total_mib()`.
pub fn mark(label: &str) {
    let total = total_mib();
    let line = format!("MARK {label} total={total}MiB");
    log_line(&line);
}
/// Writes one trace line, prefixed with `[gpumem] `.
///
/// * wasm32: the line is always sent to the browser console via
///   `web_sys::console::log_1`.
/// * other targets: the line is printed to stderr only when
///   `std::env::var("RULLAMA_TRACE_MEM")` succeeds. That lookup runs once, on
///   the first call, and the result is cached for all later calls.
fn log_line(s: &str) {
    #[cfg(target_arch = "wasm32")]
    {
        let line = format!("[gpumem] {s}");
        web_sys::console::log_1(&wasm_bindgen::JsValue::from_str(&line));
    }
    #[cfg(not(target_arch = "wasm32"))]
    {
        // Cached answer to "is tracing switched on?", filled in lazily.
        static TRACE_ENABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
        let enabled =
            *TRACE_ENABLED.get_or_init(|| std::env::var("RULLAMA_TRACE_MEM").is_ok());
        if !enabled {
            return;
        }
        eprintln!("[gpumem] {s}");
    }
}