use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::{Mutex, OnceLock};
use tracing::{debug, info, warn};
/// Process-wide tracker slot shared by [`init_tracker`], [`observe_rss`],
/// [`record_phase`] and [`shutdown`].
///
/// The cell is filled by [`init_tracker`]; the inner `Option` is emptied by
/// [`shutdown`], which takes the tracker out in order to write its report.
static GLOBAL_TRACKER: OnceLock<Mutex<Option<MemoryReportTracker>>> = OnceLock::new();
pub fn init_tracker(tracker: MemoryReportTracker) {
let _ = GLOBAL_TRACKER.set(Mutex::new(Some(tracker)));
}
/// Samples the current RSS of this process and feeds it to the global tracker.
///
/// Does nothing when [`init_tracker`] has not been called. `phase` is only used
/// to label the debug log line; it is not stored in the tracker.
pub fn observe_rss(phase: &str) {
    let Some(cell) = GLOBAL_TRACKER.get() else {
        return;
    };

    // Sample before locking so the RSS read does not happen under the mutex.
    let rss = current_rss_bytes();

    // A poisoned mutex is recovered rather than propagated as a panic.
    let mut slot = cell.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
    if let Some(active) = slot.as_mut() {
        active.observe_rss(rss);
    }
    debug!("Memory observation: phase={}, rss={} bytes", phase, rss);
}
/// Records a completed phase on the global tracker.
///
/// Forwards to [`MemoryReportTracker::record_phase`]. The call is a no-op when
/// the global tracker was never initialised or has already been taken by
/// [`shutdown`].
pub fn record_phase(name: &str, peak_rss_bytes: u64, sample_count: u64) {
    let Some(cell) = GLOBAL_TRACKER.get() else {
        return;
    };

    // A poisoned mutex is recovered rather than propagated as a panic.
    let mut slot = cell.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
    if let Some(active) = slot.as_mut() {
        active.record_phase(name, peak_rss_bytes, sample_count);
    }
}
/// Takes the global tracker out of its slot and writes its report.
///
/// A write failure is logged as a warning and not propagated. After this call
/// the slot is empty, so later [`observe_rss`] / [`record_phase`] calls no
/// longer update any tracker. Calling `shutdown` again is a no-op.
pub fn shutdown() {
    let Some(cell) = GLOBAL_TRACKER.get() else {
        return;
    };

    // The mutex guard is a temporary of this statement, so the lock is released
    // before the report is written. Directory creation, serialization and the
    // file write therefore never block concurrent callers of the other functions.
    let taken = cell
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .take();

    if let Some(mut tracker) = taken {
        if let Err(e) = tracker.write_report() {
            warn!("Failed to write memory report on shutdown: {}", e);
        }
    }
}
/// Name of the environment variable that [`resolve_report_path`] consults for
/// the report path when no CLI flag is supplied.
pub const MEMORY_REPORT_ENV: &str = "LEINDEX_MEMORY_REPORT";
/// Serializable summary produced by [`MemoryReportTracker::write_report`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryReport {
/// Report format version; `write_report` emits `1`.
pub version: u32,
/// Seconds elapsed between tracker creation and report generation.
pub uptime_secs: f64,
/// Largest RSS value, in bytes, seen through observations or recorded phases.
pub peak_rss_bytes: u64,
/// Wall-clock time of report generation, in whole seconds since the Unix epoch.
pub timestamp_secs: u64,
/// Per-phase summaries, in the order they were recorded.
pub phases: Vec<PhaseSummary>,
}
/// Summary of a single phase as passed to [`MemoryReportTracker::record_phase`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PhaseSummary {
/// Caller-chosen phase label.
pub name: String,
/// Peak RSS, in bytes, that the caller reported for this phase.
pub peak_rss_bytes: u64,
/// Caller-supplied sample count; presumably the number of RSS samples taken
/// during the phase — TODO confirm against callers.
pub sample_count: u64,
}
/// Accumulates peak-RSS observations and per-phase summaries, then writes them
/// as a JSON [`MemoryReport`].
///
/// The report is written by [`MemoryReportTracker::write_report`], which is also
/// invoked when the tracker is dropped.
#[derive(Debug)]
pub struct MemoryReportTracker {
    /// Destination file of the report.
    path: PathBuf,
    /// Creation time of the tracker; the report's uptime is measured from here.
    start: std::time::Instant,
    /// Largest RSS value, in bytes, seen so far.
    peak_rss_bytes: u64,
    /// Phases recorded so far, in insertion order.
    phases: Vec<PhaseSummary>,
    /// Set as soon as a write has been attempted; makes every later
    /// `write_report` call (including the one from `Drop`) a no-op.
    written: bool,
}
impl MemoryReportTracker {
    /// Creates a tracker whose report will be written to `path`.
    ///
    /// The uptime clock starts at construction. Nothing is written to disk
    /// until [`write_report`](Self::write_report) runs.
    pub fn new(path: PathBuf) -> Self {
        debug!("Memory report will be written to {}", path.display());
        let start = std::time::Instant::now();
        Self {
            path,
            start,
            peak_rss_bytes: 0,
            phases: Vec::new(),
            written: false,
        }
    }

    /// Raises the tracked peak to `rss_bytes` if it exceeds the current peak.
    pub fn observe_rss(&mut self, rss_bytes: u64) {
        self.peak_rss_bytes = self.peak_rss_bytes.max(rss_bytes);
    }

    /// Appends a phase summary and folds the phase's peak into the overall peak.
    pub fn record_phase(
        &mut self,
        name: impl Into<String>,
        peak_rss_bytes: u64,
        sample_count: u64,
    ) {
        let summary = PhaseSummary {
            name: name.into(),
            peak_rss_bytes,
            sample_count,
        };
        self.phases.push(summary);
        self.peak_rss_bytes = self.peak_rss_bytes.max(peak_rss_bytes);
    }

    /// Serializes the collected data as pretty-printed JSON and writes it to
    /// the tracker's path, creating missing parent directories first.
    ///
    /// Only the first call does any work: the tracker is flagged as written
    /// *before* the I/O starts, so every later call returns `Ok(())` even if
    /// the first attempt failed.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error when directory creation or the file
    /// write fails, or the serialization error converted into an I/O error.
    pub fn write_report(&mut self) -> std::io::Result<()> {
        // Flip the flag up front and bail out if it was already set.
        if std::mem::replace(&mut self.written, true) {
            return Ok(());
        }

        let uptime_secs = self.start.elapsed().as_secs_f64();
        // A system clock set before the Unix epoch yields a timestamp of 0.
        let timestamp_secs = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|since_epoch| since_epoch.as_secs())
            .unwrap_or_default();
        let report = MemoryReport {
            version: 1,
            uptime_secs,
            peak_rss_bytes: self.peak_rss_bytes,
            timestamp_secs,
            phases: self.phases.clone(),
        };

        // A bare file name has an empty parent; there is nothing to create then.
        let parent_dir = self
            .path
            .parent()
            .filter(|dir| !dir.as_os_str().is_empty());
        if let Some(dir) = parent_dir {
            std::fs::create_dir_all(dir)?;
        }

        let json = serde_json::to_string_pretty(&report)?;
        std::fs::write(&self.path, json)?;
        info!("Memory report written to {}", self.path.display());
        Ok(())
    }
}
impl Drop for MemoryReportTracker {
    /// Writes the report if that has not been attempted yet; a failure is
    /// logged as a warning because `drop` cannot return an error.
    fn drop(&mut self) {
        let Err(err) = self.write_report() else {
            return;
        };
        warn!("Failed to write memory report: {}", err);
    }
}
/// Determines where the memory report should be written.
///
/// An explicit `cli_flag` always wins. Otherwise the value of the
/// [`MEMORY_REPORT_ENV`] environment variable is used, unless it is unset or
/// empty. Returns `None` when neither source provides a path.
///
/// The variable is read as an OS string, so a path that is not valid Unicode
/// is honoured instead of being silently ignored.
pub fn resolve_report_path(cli_flag: Option<&Path>) -> Option<PathBuf> {
    cli_flag.map(Path::to_path_buf).or_else(|| {
        std::env::var_os(MEMORY_REPORT_ENV)
            .filter(|raw| !raw.is_empty())
            .map(PathBuf::from)
    })
}
/// Returns the resident set size of the current process in bytes, or `0` when
/// it cannot be determined.
///
/// Linux builds read procfs; every other target goes through `sysinfo`.
pub fn current_rss_bytes() -> u64 {
    #[cfg(target_os = "linux")]
    let sampled = read_rss_procfs_bytes();
    #[cfg(not(target_os = "linux"))]
    let sampled = read_rss_sysinfo_bytes();

    sampled.unwrap_or(0)
}
/// Reads the resident set size of the current process from `/proc/self/status`.
///
/// Returns `None` when the file cannot be read or holds no usable `VmRSS:` entry.
#[cfg(target_os = "linux")]
fn read_rss_procfs_bytes() -> Option<u64> {
    let status = std::fs::read_to_string("/proc/self/status").ok()?;
    parse_vm_rss_bytes(&status)
}

/// Extracts the `VmRSS:` value from the text of a procfs `status` file and
/// converts it from kiB to bytes.
///
/// Only the first `VmRSS:` line is considered. Returns `None` when the line is
/// missing, its value is not an unsigned integer, or the byte count would not
/// fit in a `u64`.
#[cfg(target_os = "linux")]
fn parse_vm_rss_bytes(status: &str) -> Option<u64> {
    let line = status.lines().find(|line| line.starts_with("VmRSS:"))?;
    // Layout is `VmRSS:   <value> kB`; the value is the second whitespace-separated token.
    let kib: u64 = line.split_whitespace().nth(1)?.parse().ok()?;
    // Checked so an absurd value yields `None` instead of panicking (debug) or wrapping (release).
    kib.checked_mul(1024)
}
/// Queries `sysinfo` for the memory usage of the current process.
///
/// Returns `None` when `sysinfo` has no entry for this process after the
/// refresh. The value is whatever `Process::memory` reports — presumably
/// bytes, as the function name implies; confirm against the pinned `sysinfo`
/// version.
#[cfg(not(target_os = "linux"))]
fn read_rss_sysinfo_bytes() -> Option<u64> {
    use sysinfo::{Pid, ProcessesToUpdate, System};

    let own_pid = Pid::from(std::process::id() as usize);
    let targets = [own_pid];

    // Refresh only this process rather than the whole process table.
    let mut system = System::new();
    system.refresh_processes(ProcessesToUpdate::Some(&targets), true);

    let process = system.process(own_pid)?;
    Some(process.memory())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test that reads or mutates [`MEMORY_REPORT_ENV`].
    ///
    /// The test harness runs tests on multiple threads by default, and the
    /// environment is process-global, so unsynchronized `set_var` /
    /// `remove_var` calls make these tests interfere with each other.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds [`ENV_LOCK`] for the duration of a test and removes
    /// [`MEMORY_REPORT_ENV`] again when dropped, even if an assertion fails.
    struct EnvGuard {
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        /// Takes the lock and makes sure the variable is unset.
        fn cleared() -> Self {
            // A failed assertion in another test poisons the lock; the
            // environment is reset below anyway, so the poison is ignored.
            let lock = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner());
            std::env::remove_var(MEMORY_REPORT_ENV);
            Self { _lock: lock }
        }

        /// Takes the lock and sets the variable to `value`.
        fn with_value(value: &str) -> Self {
            let guard = Self::cleared();
            std::env::set_var(MEMORY_REPORT_ENV, value);
            guard
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Runs before the lock field is released, so no other test can
            // observe a stale value.
            std::env::remove_var(MEMORY_REPORT_ENV);
        }
    }

    #[test]
    fn test_resolve_report_path_none_when_unset() {
        let _env = EnvGuard::cleared();
        assert!(resolve_report_path(None).is_none());
    }

    #[test]
    fn test_resolve_report_path_from_flag() {
        let path = Path::new("/tmp/test-report.json");
        let result = resolve_report_path(Some(path));
        assert_eq!(result, Some(PathBuf::from("/tmp/test-report.json")));
    }

    #[test]
    fn test_resolve_report_path_from_env() {
        let _env = EnvGuard::with_value("/tmp/env-report.json");
        let result = resolve_report_path(None);
        assert_eq!(result, Some(PathBuf::from("/tmp/env-report.json")));
    }

    #[test]
    fn test_flag_takes_precedence_over_env() {
        let _env = EnvGuard::with_value("/tmp/env-report.json");
        let flag_path = Path::new("/tmp/flag-report.json");
        let result = resolve_report_path(Some(flag_path));
        assert_eq!(result, Some(PathBuf::from("/tmp/flag-report.json")));
    }

    #[test]
    fn test_empty_env_var_ignored() {
        let _env = EnvGuard::with_value("");
        let result = resolve_report_path(None);
        assert!(
            result.is_none(),
            "empty env var should be ignored, got {:?}",
            result
        );
    }

    #[test]
    fn test_tracker_writes_report() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("report.json");
        let mut tracker = MemoryReportTracker::new(path.clone());
        tracker.observe_rss(100_000_000);
        tracker.record_phase("index", 150_000_000, 42);
        tracker.write_report().unwrap();
        let contents = fs::read_to_string(&path).unwrap();
        let report: MemoryReport = serde_json::from_str(&contents).unwrap();
        assert_eq!(report.version, 1);
        assert_eq!(report.peak_rss_bytes, 150_000_000);
        assert_eq!(report.phases.len(), 1);
        assert_eq!(report.phases[0].name, "index");
        assert_eq!(report.phases[0].peak_rss_bytes, 150_000_000);
        assert_eq!(report.phases[0].sample_count, 42);
        assert!(report.uptime_secs >= 0.0);
        assert!(report.timestamp_secs > 0);
    }

    #[test]
    fn test_tracker_drop_writes_report() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("drop-report.json");
        {
            let mut tracker = MemoryReportTracker::new(path.clone());
            tracker.observe_rss(50_000_000);
        }
        let contents = fs::read_to_string(&path).unwrap();
        let report: MemoryReport = serde_json::from_str(&contents).unwrap();
        assert_eq!(report.peak_rss_bytes, 50_000_000);
    }

    #[test]
    fn test_report_is_compact() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("compact-report.json");
        let mut tracker = MemoryReportTracker::new(path.clone());
        tracker.observe_rss(100_000_000);
        tracker.record_phase("index", 150_000_000, 42);
        tracker.record_phase("query", 120_000_000, 10);
        tracker.write_report().unwrap();
        let contents = fs::read_to_string(&path).unwrap();
        assert!(
            contents.len() < 2048,
            "Report should be compact, got {} bytes",
            contents.len()
        );
    }

    #[test]
    fn test_double_write_is_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("idempotent-report.json");
        let mut tracker = MemoryReportTracker::new(path.clone());
        tracker.observe_rss(100_000_000);
        tracker.write_report().unwrap();
        tracker.write_report().unwrap();
        let contents = fs::read_to_string(&path).unwrap();
        let report: MemoryReport = serde_json::from_str(&contents).unwrap();
        assert_eq!(report.peak_rss_bytes, 100_000_000);
    }

    #[test]
    fn test_current_rss_bytes_is_reasonable() {
        let rss = current_rss_bytes();
        assert!(rss > 0, "RSS should be positive, got {}", rss);
        assert!(
            rss < 10_000_000_000,
            "RSS should be less than 10 GB, got {}",
            rss
        );
    }
}