use std::sync::atomic::{AtomicU64, Ordering};
/// Hint the kernel to asynchronously page in the memory range `[ptr, ptr + len)`.
///
/// The range is expanded *outward* to page boundaries (`madvise` requires a
/// page-aligned start address). Expanding outward is safe for
/// `MADV_WILLNEED`: prefetching a few neighboring bytes is harmless.
/// Null pointers and empty ranges are ignored, as is any kernel error —
/// this call is purely advisory.
pub fn prefetch_pages(ptr: *const u8, len: usize) {
    if ptr.is_null() || len == 0 {
        return;
    }
    // Query the real page size: hardcoding 4096 breaks on 16 KiB-page systems
    // (e.g. Apple Silicon, some aarch64 kernels), where a merely-4096-aligned
    // address makes madvise fail with EINVAL and silently do nothing.
    let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
    let page_size = if page_size > 0 { page_size as usize } else { 4096 };
    let start = ptr as usize & !(page_size - 1);
    // Round the END of the range up, not just `len`: if `ptr` sits mid-page,
    // rounding only `len` can leave the final page of the range uncovered.
    let end = (ptr as usize)
        .saturating_add(len)
        .saturating_add(page_size - 1)
        & !(page_size - 1);
    let aligned_len = end - start;
    unsafe {
        // SAFETY: `start..start+aligned_len` covers only pages overlapping the
        // caller-supplied range; MADV_WILLNEED is a non-destructive hint and
        // the return value is intentionally ignored (best effort).
        libc::madvise(start as *mut libc::c_void, aligned_len, libc::MADV_WILLNEED);
    }
}
/// Issue a prefetch hint for every `(ptr, len)` range in `ranges`.
///
/// Simply delegates each range to [`prefetch_pages`]; null or empty
/// ranges are filtered out there.
pub fn prefetch_batch(ranges: &[(*const u8, usize)]) {
    ranges
        .iter()
        .copied()
        .for_each(|(start, size)| prefetch_pages(start, size));
}
/// Tell the kernel the memory range `[ptr, ptr + len)` is no longer needed.
///
/// Unlike [`prefetch_pages`], the range is shrunk *inward* to whole pages:
/// `MADV_DONTNEED` discards page contents (anonymous pages subsequently read
/// back as zeroes), so advising a page that is only partially inside the
/// range would destroy neighboring live data sharing that page. Partial
/// pages at either end are left alone. No-op for null/empty ranges or for
/// ranges that do not contain at least one whole page.
pub fn advise_dontneed(ptr: *const u8, len: usize) {
    if ptr.is_null() || len == 0 {
        return;
    }
    // Use the real page size; see prefetch_pages for why 4096 is only a fallback.
    let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
    let page_size = if page_size > 0 { page_size as usize } else { 4096 };
    // Round the start UP and the end DOWN so only pages fully contained in
    // the caller's range are affected. (Outward rounding here — as with a
    // prefetch — would discard data on the shared first/last pages.)
    let start = (ptr as usize).saturating_add(page_size - 1) & !(page_size - 1);
    let end = ((ptr as usize).saturating_add(len)) & !(page_size - 1);
    if end <= start {
        // The range does not cover a single whole page — nothing safe to drop.
        return;
    }
    unsafe {
        // SAFETY: `start..end` lies strictly within the caller-supplied range,
        // so no bytes outside it can be discarded. Errors are ignored: this
        // is a best-effort memory-pressure hint.
        libc::madvise(start as *mut libc::c_void, end - start, libc::MADV_DONTNEED);
    }
}
/// Per-core page-fault statistics used to decide when prefetching is
/// falling behind (the "degraded" state).
///
/// All counters use relaxed atomics: each is an independent, monotonically
/// increasing statistic, so no ordering with other memory operations is
/// required. Readers may observe slightly stale values; that is acceptable
/// for rate estimation.
pub struct FaultCounter {
    // Total page faults recorded since creation.
    total_faults: AtomicU64,
    // Value of `total_faults` at the last `snapshot()`; the difference is
    // the per-interval fault count used by `is_degraded`.
    last_snapshot: AtomicU64,
    // Number of prefetch hints issued.
    prefetch_issued: AtomicU64,
    // Number of prefetches that failed to prevent a fault.
    prefetch_misses: AtomicU64,
    // Interval fault count at or above which the counter reports degraded.
    threshold: u64,
    // CPU core this counter is associated with (informational only).
    core_id: usize,
}

impl FaultCounter {
    /// Create a counter for `core_id` that reports degraded once the
    /// per-interval fault count reaches `threshold`.
    pub fn new(core_id: usize, threshold: u64) -> Self {
        Self {
            total_faults: AtomicU64::new(0),
            last_snapshot: AtomicU64::new(0),
            prefetch_issued: AtomicU64::new(0),
            prefetch_misses: AtomicU64::new(0),
            threshold,
            core_id,
        }
    }

    /// Record one page fault.
    pub fn record_fault(&self) {
        self.total_faults.fetch_add(1, Ordering::Relaxed);
    }

    /// Record one issued prefetch hint.
    pub fn record_prefetch(&self) {
        self.prefetch_issued.fetch_add(1, Ordering::Relaxed);
    }

    /// Record one prefetch that did not prevent a fault.
    pub fn record_miss(&self) {
        self.prefetch_misses.fetch_add(1, Ordering::Relaxed);
    }

    /// True when faults recorded since the last `snapshot()` have reached
    /// the configured threshold.
    pub fn is_degraded(&self) -> bool {
        let total = self.total_faults.load(Ordering::Relaxed);
        let last = self.last_snapshot.load(Ordering::Relaxed);
        // saturating_sub guards against a racing snapshot() storing a value
        // newer than the `total` we loaded above.
        total.saturating_sub(last) >= self.threshold
    }

    /// Start a new measurement interval: future `is_degraded()` /
    /// `interval_faults()` results count faults from this point on.
    pub fn snapshot(&self) {
        let total = self.total_faults.load(Ordering::Relaxed);
        self.last_snapshot.store(total, Ordering::Relaxed);
    }

    /// Total faults recorded since creation.
    pub fn total_faults(&self) -> u64 {
        self.total_faults.load(Ordering::Relaxed)
    }

    /// Faults recorded since the last `snapshot()` (zero right after one).
    pub fn interval_faults(&self) -> u64 {
        let total = self.total_faults.load(Ordering::Relaxed);
        let last = self.last_snapshot.load(Ordering::Relaxed);
        total.saturating_sub(last)
    }

    /// Fraction of issued prefetches that did NOT miss, in `[0, 1]`.
    /// Returns 1.0 when no prefetches have been issued (vacuous success).
    pub fn hit_rate(&self) -> f64 {
        let issued = self.prefetch_issued.load(Ordering::Relaxed);
        let misses = self.prefetch_misses.load(Ordering::Relaxed);
        if issued == 0 {
            1.0
        } else {
            1.0 - (misses as f64 / issued as f64)
        }
    }

    /// Complement of `hit_rate()`.
    pub fn miss_rate(&self) -> f64 {
        1.0 - self.hit_rate()
    }

    /// Number of prefetch hints issued so far.
    pub fn prefetch_issued(&self) -> u64 {
        self.prefetch_issued.load(Ordering::Relaxed)
    }

    /// Number of recorded prefetch misses so far.
    pub fn prefetch_misses(&self) -> u64 {
        self.prefetch_misses.load(Ordering::Relaxed)
    }

    /// CPU core this counter belongs to.
    pub fn core_id(&self) -> usize {
        self.core_id
    }

    /// Read this process's major page-fault count (`majflt`) from
    /// `/proc/self/stat`. Linux only; returns `None` elsewhere or on any
    /// read/parse failure.
    pub fn read_process_majflt() -> Option<u64> {
        let stat = std::fs::read_to_string("/proc/self/stat").ok()?;
        // Field 2 (`comm`) is the executable name wrapped in parentheses and
        // may itself contain spaces or ')', so naively whitespace-splitting
        // the whole line can mis-index. Skip past the LAST ')' and count
        // fields from there: state is field 3, so majflt — field 12 per
        // proc(5) — is the 10th whitespace-separated token after the paren.
        let after_comm = &stat[stat.rfind(')')? + 1..];
        after_comm.split_whitespace().nth(9)?.parse::<u64>().ok()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Counting, degradation threshold, and snapshot reset.
    #[test]
    fn fault_counter_basics() {
        let fc = FaultCounter::new(0, 10);
        assert!(!fc.is_degraded());
        assert_eq!(fc.total_faults(), 0);
        (0..10).for_each(|_| fc.record_fault());
        assert!(fc.is_degraded());
        assert_eq!(fc.total_faults(), 10);
        fc.snapshot();
        assert!(!fc.is_degraded());
        assert_eq!(fc.interval_faults(), 0);
    }

    /// 100 prefetches with 10 misses => 90% hit rate, 10% miss rate.
    #[test]
    fn prefetch_hit_rate() {
        let fc = FaultCounter::new(0, 100);
        (0..100).for_each(|_| fc.record_prefetch());
        (0..10).for_each(|_| fc.record_miss());
        let hits = fc.hit_rate();
        assert!((hits - 0.9).abs() < 0.01, "hit_rate: {hits}");
        let misses = fc.miss_rate();
        assert!((misses - 0.1).abs() < 0.01);
    }

    /// Null pointers and zero lengths must be ignored, not dereferenced.
    #[test]
    fn prefetch_pages_null_safe() {
        prefetch_pages(std::ptr::null(), 0);
        prefetch_pages(std::ptr::null(), 4096);
    }

    /// On Linux, /proc/self/stat should always yield a majflt value.
    #[test]
    fn read_majflt() {
        #[cfg(target_os = "linux")]
        {
            assert!(FaultCounter::read_process_majflt().is_some());
        }
    }
}