use bytesize::ByteSize;
use log::{Level, info, log_enabled};
#[cfg(feature = "memory-debug")]
pub use platform_ffi::print_mi_stats;
pub use platform_ffi::{force_mi_collect, process_rss_bytes};
/// Log target for every message this module emits; filter on this target to
/// turn probe output on or off independently of the rest of the crate.
const TARGET: &str = "fgumi_lib::sort::memory_probe";
#[allow(unsafe_code)]
mod platform_ffi {
    /// Asks mimalloc to eagerly collect freed memory.
    ///
    /// Called between the two RSS samples in `collect_rss_snapshot` so the
    /// "collected" delta reflects memory the allocator was still holding.
    pub fn force_mi_collect() {
        // SAFETY: FFI call with no pointer arguments; `mi_collect(true)`
        // only triggers allocator housekeeping, no Rust invariants involved.
        unsafe {
            libmimalloc_sys::mi_collect(true);
        }
    }

    /// Dumps mimalloc's internal statistics (memory-debug builds only).
    #[cfg(feature = "memory-debug")]
    pub fn print_mi_stats() {
        // SAFETY: a `None` output callback plus a null user argument selects
        // mimalloc's default output destination; no Rust memory is exposed.
        unsafe {
            libmimalloc_sys::mi_stats_print_out(None, std::ptr::null_mut());
        }
    }

    /// Samples this process's resident memory via the Mach `task_info` API.
    ///
    /// Returns `None` if the kernel call fails or the struct-size computation
    /// does not fit the Mach count type. NOTE(review): this returns
    /// `phys_footprint`, the footprint metric macOS tooling typically
    /// reports — not strictly identical to classic RSS; confirm that is the
    /// intended metric.
    #[cfg(target_os = "macos")]
    #[must_use]
    pub fn process_rss_bytes() -> Option<u64> {
        use mach2::message::mach_msg_type_number_t;
        use mach2::task::task_info;
        use mach2::task_info::{TASK_VM_INFO, task_vm_info};
        use mach2::traps::mach_task_self;
        use mach2::vm_types::natural_t;
        let mut info = task_vm_info::default();
        // task_info expresses the buffer size in units of natural_t, not bytes.
        let mut count: mach_msg_type_number_t = mach_msg_type_number_t::try_from(
            std::mem::size_of::<task_vm_info>() / std::mem::size_of::<natural_t>(),
        )
        .ok()?;
        // SAFETY: `info` and `count` are valid, correctly sized out-pointers
        // that outlive the call; `mach_task_self()` yields the current task
        // port expected by `task_info`.
        let kr = unsafe {
            task_info(
                mach_task_self(),
                TASK_VM_INFO,
                std::ptr::addr_of_mut!(info).cast(),
                std::ptr::addr_of_mut!(count),
            )
        };
        // Non-zero kern_return_t means the query failed (KERN_SUCCESS is 0).
        if kr != 0 {
            return None;
        }
        Some(info.phys_footprint)
    }

    /// Fallback for platforms without an RSS sampling implementation.
    #[cfg(not(any(target_os = "linux", target_os = "macos")))]
    #[must_use]
    pub fn process_rss_bytes() -> Option<u64> {
        None
    }

    /// Samples this process's resident set size from `/proc/self/status`.
    ///
    /// Parses the `VmRSS:` line (the kernel reports the value in kB) and
    /// converts it to bytes; returns `None` if the file, line, or numeric
    /// field is missing or unparsable.
    #[cfg(target_os = "linux")]
    #[must_use]
    pub fn process_rss_bytes() -> Option<u64> {
        let status = std::fs::read_to_string("/proc/self/status").ok()?;
        status
            .lines()
            .find(|line| line.starts_with("VmRSS:"))?
            .split_whitespace()
            .nth(1)?
            .parse::<u64>()
            .ok()
            .map(|kb| kb * 1024)
    }
}
/// Returns `true` when probe logging is active, i.e. `Info`-level logging is
/// enabled for this module's [`TARGET`]. Probe entry points check this first
/// so disabled runs skip RSS sampling and string formatting entirely.
#[must_use]
#[inline]
pub fn enabled() -> bool {
    log_enabled!(target: TARGET, Level::Info)
}
fn fmt_bytes(bytes: u64) -> String {
ByteSize(bytes).to_string()
}
/// Pre-formatted RSS readings taken around a forced allocator collection.
struct RssSnapshot {
    // RSS before collection, human-readable; "?" if sampling failed.
    rss_str: String,
    // RSS after `force_mi_collect`, human-readable; "?" if sampling failed.
    post_collect_str: String,
    // Bytes released by the collection (pre minus post, saturating), or "?".
    collected_str: String,
    // Raw pre-collection RSS in bytes, kept for residual computations.
    rss: Option<u64>,
}
/// Samples RSS, forces a mimalloc collection, samples RSS again, and returns
/// all readings pre-formatted for logging. Failed samples render as "?".
fn collect_rss_snapshot() -> RssSnapshot {
    let render = |sample: Option<u64>| sample.map_or_else(|| "?".to_string(), fmt_bytes);
    let before = process_rss_bytes();
    let rss_str = render(before);
    force_mi_collect();
    let after = process_rss_bytes();
    let post_collect_str = render(after);
    // "collected" is how much the forced collection shrank RSS; growth
    // saturates to zero, and either failed sample yields "?".
    let collected_str = before.zip(after).map_or_else(
        || "?".to_string(),
        |(pre, post)| fmt_bytes(pre.saturating_sub(post)),
    );
    RssSnapshot { rss_str, post_collect_str, collected_str, rss: before }
}
/// Logs a one-line memory snapshot tagged `label`, comparing sampled RSS
/// against `tracked_bytes` (memory the caller accounts for). The untracked
/// remainder is reported as `residual` together with its share of RSS.
/// No-op unless probe logging is enabled.
pub fn log_snapshot(label: &str, tracked_bytes: u64) {
    if !enabled() {
        return;
    }
    let snap = collect_rss_snapshot();
    let tracked_str = fmt_bytes(tracked_bytes);
    if let Some(rss) = snap.rss {
        // Residual = RSS the caller cannot account for (saturating at zero).
        let residual = rss.saturating_sub(tracked_bytes);
        let residual_str = fmt_bytes(residual);
        let pct = if rss == 0 { 0 } else { (residual * 100) / rss };
        info!(
            target: TARGET,
            "MEM[{label}] rss={} post_collect={} collected={} tracked={tracked_str} residual={residual_str} residual_pct={pct}%",
            snap.rss_str, snap.post_collect_str, snap.collected_str,
        );
    } else {
        // RSS sampling unavailable on this platform or this run.
        info!(
            target: TARGET,
            "MEM[{label}] rss=? tracked={tracked_str} residual=? residual_pct=?"
        );
    }
}
/// Point-in-time buffer statistics passed to phase-1 snapshot loggers.
#[derive(Copy, Clone, Debug)]
pub struct BufferProbeStats {
    /// Bytes currently in use by the buffer.
    pub usage: u64,
    /// Bytes currently reserved (capacity) by the buffer.
    pub capacity: u64,
    /// Records currently buffered.
    pub records: u64,
    /// Segments currently buffered.
    pub segments: u64,
}

impl BufferProbeStats {
    /// Builds a stats record carrying only usage and record count; the
    /// capacity and segment counters default to zero.
    #[must_use]
    pub fn simple(usage: u64, records: u64) -> Self {
        Self { capacity: 0, segments: 0, usage, records }
    }
}
/// Logs a phase-1 snapshot tagged `label`, optionally including buffer
/// statistics and pipeline pool depths. When buffer stats are present their
/// `usage` counts as the tracked bytes for the residual computation.
/// No-op unless probe logging is enabled.
fn log_phase1_snapshot(
    label: &str,
    buf_stats: Option<BufferProbeStats>,
    pool_depths: Option<(usize, usize, usize)>,
) {
    if !enabled() {
        return;
    }
    let snap = collect_rss_snapshot();
    // Optional " buf_use=... buf_cap=... recs=... segs=..." fragment.
    let buf_str = buf_stats.map_or_else(String::new, |s| {
        format!(
            " buf_use={} buf_cap={} recs={} segs={}",
            fmt_bytes(s.usage),
            fmt_bytes(s.capacity),
            s.records,
            s.segments,
        )
    });
    // Optional " raw_q=... decomp_q=... buf_q=..." fragment.
    let pool_str = pool_depths.map_or_else(String::new, |(raw_q, decomp_q, buf_q)| {
        format!(" raw_q={raw_q} decomp_q={decomp_q} buf_q={buf_q}")
    });
    let tracked = buf_stats.map_or(0, |s| s.usage);
    // Residual fragment; "?" when RSS sampling failed.
    let residual_str = snap.rss.map_or_else(
        || " residual=? residual_pct=?".to_string(),
        |r| {
            let residual = r.saturating_sub(tracked);
            let pct = if r == 0 { 0 } else { (residual * 100) / r };
            format!(" residual={} residual_pct={pct}%", fmt_bytes(residual))
        },
    );
    info!(
        target: TARGET,
        "MEM[{label}] rss={} post_collect={} collected={}{buf_str}{pool_str}{residual_str}",
        snap.rss_str, snap.post_collect_str, snap.collected_str,
    );
}
/// Memory probe for phase 1 (read/spill): logs snapshots at phase start and
/// end, around each spill, and at record-count intervals while reading.
pub struct SpillProbe {
    // Label prefix for every snapshot this probe emits.
    phase: &'static str,
    // Index of the current spill cycle; advanced in `post_spill`.
    spill_idx: usize,
    // Index of the next mid-read sample within the current spill cycle.
    read_sample_idx: usize,
    // Record count at which the next mid-read sample fires.
    next_read_threshold: u64,
}
impl SpillProbe {
    /// Number of records read between successive mid-read samples.
    pub const READ_SAMPLE_INTERVAL: u64 = 1_000_000;

    /// Creates a probe for `phase` and immediately logs its start snapshot.
    #[must_use]
    pub fn new(phase: &'static str) -> Self {
        log_snapshot(&format!("{phase}.start"), 0);
        Self {
            next_read_threshold: Self::READ_SAMPLE_INTERVAL,
            read_sample_idx: 0,
            spill_idx: 0,
            phase,
        }
    }

    /// Cheap gate: `true` once `records_read` has crossed the next sampling
    /// threshold and probe logging is enabled at all.
    #[inline]
    #[must_use]
    pub fn should_sample_read(&self, records_read: u64) -> bool {
        if !enabled() {
            return false;
        }
        records_read >= self.next_read_threshold
    }

    /// Emits a mid-read snapshot, then advances the sample index and the
    /// record-count threshold for the next sample.
    pub fn log_mid_read(
        &mut self,
        buf_stats: BufferProbeStats,
        pool_depths: Option<(usize, usize, usize)>,
    ) {
        let label = format!("{}.mid_read_{}", self.phase, self.read_sample_idx);
        log_phase1_snapshot(&label, Some(buf_stats), pool_depths);
        self.read_sample_idx += 1;
        self.next_read_threshold =
            self.next_read_threshold.saturating_add(Self::READ_SAMPLE_INTERVAL);
    }

    /// Logs a snapshot just before the current spill begins.
    pub fn pre_spill(
        &self,
        buf_stats: BufferProbeStats,
        pool_depths: Option<(usize, usize, usize)>,
    ) {
        let label = format!("{}.pre_spill_{}", self.phase, self.spill_idx);
        log_phase1_snapshot(&label, Some(buf_stats), pool_depths);
    }

    /// Logs a snapshot after the buffer has been drained, still within the
    /// current spill cycle (the spill index has not advanced yet).
    pub fn post_drain(
        &self,
        buf_stats: BufferProbeStats,
        pool_depths: Option<(usize, usize, usize)>,
    ) {
        let label = format!("{}.post_drain_{}", self.phase, self.spill_idx);
        log_phase1_snapshot(&label, Some(buf_stats), pool_depths);
    }

    /// Logs the end-of-spill snapshot, then starts the next spill cycle:
    /// the spill index advances and mid-read sample numbering restarts.
    pub fn post_spill(&mut self, pool_depths: Option<(usize, usize, usize)>) {
        let label = format!("{}.post_spill_{}", self.phase, self.spill_idx);
        log_phase1_snapshot(&label, None, pool_depths);
        self.spill_idx += 1;
        // NOTE(review): only the sample index resets here — the record-count
        // threshold keeps accumulating, which presumes `records_read` is a
        // monotone per-phase counter; confirm against callers.
        self.read_sample_idx = 0;
    }

    /// Logs the phase-end snapshot with the caller's final tracked byte count.
    pub fn phase1_end(&self, tracked: u64) {
        let label = format!("{}.end", self.phase);
        log_snapshot(&label, tracked);
    }

    /// Test-only accessor for the number of completed spills.
    #[cfg(test)]
    #[must_use]
    pub fn spill_count(&self) -> usize {
        self.spill_idx
    }
}
/// Point-in-time consumer-side statistics for phase-2 (merge) snapshots.
/// Values are whatever the caller tracks; this type only carries them.
#[derive(Copy, Clone, Debug)]
pub struct ConsumerProbeStats {
    // Bytes currently held by the consumer.
    pub current_bytes: u64,
    // Bytes currently reserved (capacity) by the consumer.
    pub current_capacity: u64,
    // Blocks waiting to be consumed.
    pub pending_blocks: u64,
    // Total bytes across those pending blocks.
    pub pending_bytes: u64,
    // Sources still active in the merge (presumably open spill runs —
    // confirm against the caller).
    pub active_sources: u64,
}
/// Memory probe for phase 2 (merge): logs a snapshot at phase start and at
/// fixed merged-record intervals.
pub struct MergeProbe {
    // Index of the next mid-merge sample.
    sample_idx: usize,
    // Merged-record count at which the next sample fires.
    next_threshold: u64,
}
impl MergeProbe {
    /// Number of merged records between successive mid-merge samples.
    pub const SAMPLE_INTERVAL_RECORDS: u64 = 1_000_000;

    /// Creates the probe and immediately logs the `phase2.start` snapshot.
    #[must_use]
    pub fn new() -> Self {
        log_snapshot("phase2.start", 0);
        Self { next_threshold: Self::SAMPLE_INTERVAL_RECORDS, sample_idx: 0 }
    }

    /// Test-only sampling driver: once `records_merged` reaches the current
    /// threshold, emits a snapshot (when logging is enabled) and advances the
    /// sample index and threshold. The counters advance even with logging
    /// disabled, so the sampling cadence stays deterministic.
    #[cfg(test)]
    #[inline]
    pub fn record(&mut self, records_merged: u64) {
        if records_merged < self.next_threshold {
            return;
        }
        if enabled() {
            let label = format!("phase2.mid_{}", self.sample_idx);
            log_snapshot(&label, 0);
        }
        self.sample_idx += 1;
        self.next_threshold =
            self.next_threshold.saturating_add(Self::SAMPLE_INTERVAL_RECORDS);
    }

    /// Test-only accessor for how many samples have been taken.
    #[cfg(test)]
    #[must_use]
    pub fn sample_count(&self) -> usize {
        self.sample_idx
    }

    /// Cheap gate: `true` once `records_merged` has crossed the next
    /// threshold and probe logging is enabled at all.
    #[inline]
    #[must_use]
    pub fn should_sample(&self, records_merged: u64) -> bool {
        if !enabled() {
            return false;
        }
        records_merged >= self.next_threshold
    }

    /// Emits a mid-merge snapshot with pipeline pool depths and optional
    /// consumer statistics, then advances the sample index and threshold
    /// (the advance happens even when logging is disabled).
    pub fn log_mid_with_depths(
        &mut self,
        pool_depths: (usize, usize, usize),
        consumer_stats: Option<ConsumerProbeStats>,
    ) {
        if enabled() {
            let snap = collect_rss_snapshot();
            let (raw_q, decomp_q, buf_q) = pool_depths;
            // Optional consumer-side fragment appended to the log line.
            let consumer_str = consumer_stats.map_or_else(String::new, |s| {
                format!(
                    " cur_bytes={} cur_cap={} pend_blocks={} pend_bytes={} active_src={}",
                    fmt_bytes(s.current_bytes),
                    fmt_bytes(s.current_capacity),
                    s.pending_blocks,
                    fmt_bytes(s.pending_bytes),
                    s.active_sources,
                )
            });
            info!(
                target: TARGET,
                "MEM[phase2.mid_{}] rss={} post_collect={} collected={} raw_q={raw_q} decomp_q={decomp_q} buf_q={buf_q}{consumer_str}",
                self.sample_idx, snap.rss_str, snap.post_collect_str, snap.collected_str,
            );
        }
        self.sample_idx += 1;
        self.next_threshold =
            self.next_threshold.saturating_add(Self::SAMPLE_INTERVAL_RECORDS);
    }
}
impl Default for MergeProbe {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Pins the exact strings `fmt_bytes` produces (IEC units via `ByteSize`),
    // since log-parsing tooling may depend on this format.
    #[test]
    fn test_fmt_bytes_units() {
        assert_eq!(fmt_bytes(0), "0 B");
        assert_eq!(fmt_bytes(512), "512 B");
        assert_eq!(fmt_bytes(1024), "1.0 KiB");
        assert_eq!(fmt_bytes(1536), "1.5 KiB");
        assert_eq!(fmt_bytes(1024 * 1024), "1.0 MiB");
        assert_eq!(fmt_bytes(1024 * 1024 * 1024), "1.0 GiB");
        assert_eq!(fmt_bytes(2u64 * 1024 * 1024 * 1024 + 512 * 1024 * 1024), "2.5 GiB");
    }

    // Sanity-checks live RSS sampling: the test process should report
    // between 1 MiB and 1 TiB of resident memory.
    #[test]
    fn test_process_rss_bytes_returns_plausible_value() {
        let rss = process_rss_bytes().expect("RSS sampling should work on supported platforms");
        assert!(rss >= 1024 * 1024, "RSS {rss} is implausibly small");
        assert!(rss < 1024u64 * 1024 * 1024 * 1024, "RSS {rss} is implausibly large");
    }

    // With no logger installed, `log_snapshot` must be a quiet no-op.
    #[test]
    fn test_log_snapshot_does_not_panic_without_logger() {
        log_snapshot("test.label", 1024);
        log_snapshot("test.label", 0);
    }

    // The spill index advances once per post_spill, independent of logging.
    #[test]
    fn test_spill_probe_increments() {
        let mut probe = SpillProbe::new("test_phase");
        assert_eq!(probe.spill_count(), 0);
        let stats = BufferProbeStats { usage: 1024, capacity: 2048, records: 10, segments: 1 };
        probe.pre_spill(stats, None);
        probe.post_spill(None);
        assert_eq!(probe.spill_count(), 1);
        let stats2 = BufferProbeStats { usage: 2048, capacity: 4096, records: 20, segments: 1 };
        probe.pre_spill(stats2, None);
        probe.post_spill(None);
        assert_eq!(probe.spill_count(), 2);
        probe.phase1_end(0);
    }

    // Samples fire exactly when the merged-record count crosses each
    // successive SAMPLE_INTERVAL_RECORDS threshold, and never in between.
    #[test]
    fn test_merge_probe_samples_at_interval() {
        let mut probe = MergeProbe::new();
        assert_eq!(probe.sample_count(), 0);
        probe.record(MergeProbe::SAMPLE_INTERVAL_RECORDS - 1);
        assert_eq!(probe.sample_count(), 0);
        probe.record(MergeProbe::SAMPLE_INTERVAL_RECORDS);
        assert_eq!(probe.sample_count(), 1);
        probe.record(MergeProbe::SAMPLE_INTERVAL_RECORDS + 1);
        assert_eq!(probe.sample_count(), 1);
        probe.record(MergeProbe::SAMPLE_INTERVAL_RECORDS * 2);
        assert_eq!(probe.sample_count(), 2);
    }
}