#![allow(missing_docs)]
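//! Epoch-based deferred cleanup: a global epoch counter, per-thread epoch
//! pins, safe-point detection across all pinned threads, and a queue of
//! cleanup closures that run only once no thread can still observe the
//! epoch in which they were deferred.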
use crate::types::{RegionId, TaskId};
use crate::util::CachePadded;
use crossbeam_queue::SegQueue;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::time::{Duration, Instant};
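/// Tuning parameters for epoch advancement and cleanup scheduling.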
#[derive(Debug, Clone)]
pub struct EpochConfig {
pub advance_interval: Duration,
pub max_cleanup_delay: Duration,
pub auto_advance_on_pressure: bool,
pub pressure_threshold: usize,
}
impl Default for EpochConfig {
fn default() -> Self {
Self {
advance_interval: Duration::from_millis(100),
max_cleanup_delay: Duration::from_secs(1),
auto_advance_on_pressure: true,
            pressure_threshold: 10 * 1024 * 1024,
        }
}
}
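/// Point-in-time snapshot of global epoch state.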
#[derive(Debug, Clone)]
pub struct EpochStats {
pub current_epoch: u64,
pub advance_count: u64,
pub last_advance: Instant,
pub active_pins: usize,
pub min_pinned_epoch: u64,
}
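/// Counters describing a [`DeferredCleanupQueue`].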
#[derive(Debug, Clone)]
pub struct CleanupQueueStats {
pub pending_count: usize,
pub enqueue_count: u64,
pub execute_count: u64,
pub memory_usage: usize,
pub total_cleanup_time: Duration,
}
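/// Per-pin counters; pin timestamps and durations are packed into atomics
/// as nanoseconds.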
#[derive(Debug, Default)]
pub struct LocalEpochStats {
pub pin_count: AtomicU64,
pub unpin_count: AtomicU64,
pub max_pin_duration: AtomicU64,
pub pin_start: AtomicU64,
}
/// Monotonically increasing global epoch; `try_advance` is rate-limited to
/// one advance per `advance_interval`, `force_advance` is unconditional.
pub struct GlobalEpochCounter {
    epoch: AtomicU64,
    /// Nanoseconds since `created_at` at which the epoch last advanced.
    last_advance: AtomicU64,
    advance_interval: Duration,
    advance_count: AtomicU64,
    /// Monotonic base for packing `Instant`s into `AtomicU64` fields.
    created_at: Instant,
    config: EpochConfig,
}
impl GlobalEpochCounter {
#[must_use]
pub fn new(config: EpochConfig) -> Self {
        Self {
            epoch: AtomicU64::new(1),
            last_advance: AtomicU64::new(0),
            advance_interval: config.advance_interval,
            advance_count: AtomicU64::new(0),
            created_at: Instant::now(),
            config,
        }
}
#[inline]
pub fn current_epoch(&self) -> u64 {
self.epoch.load(Ordering::Acquire)
}
    pub fn try_advance(&self) -> Option<u64> {
        // Rate-limit to one advance per `advance_interval`; CAS picks one winner.
        let now_ns = self.elapsed_ns();
        let last = self.last_advance.load(Ordering::Acquire);
        if now_ns.saturating_sub(last) < self.advance_interval.as_nanos() as u64 {
            return None;
        }
        self.last_advance
            .compare_exchange(last, now_ns, Ordering::AcqRel, Ordering::Acquire)
            .ok()?;
        Some(self.advance_epoch())
    }
    pub fn force_advance(&self) -> u64 {
        self.last_advance.store(self.elapsed_ns(), Ordering::Release);
        self.advance_epoch()
    }
pub fn should_advance_on_pressure(&self, memory_usage: usize) -> bool {
self.config.auto_advance_on_pressure && memory_usage >= self.config.pressure_threshold
}
    pub fn stats(&self) -> EpochStats {
        EpochStats {
            current_epoch: self.current_epoch(),
            advance_count: self.advance_count.load(Ordering::Relaxed),
            last_advance: self.created_at
                + Duration::from_nanos(self.last_advance.load(Ordering::Acquire)),
            // Pin bookkeeping lives in `SafePointDetector`, not here.
            active_pins: 0,
            min_pinned_epoch: 0,
        }
    }
    /// Nanoseconds since this counter was created: the monotonic base used
    /// to pack timestamps into `AtomicU64` fields.
    fn elapsed_ns(&self) -> u64 {
        self.created_at.elapsed().as_nanos() as u64
    }
    #[cold]
    fn advance_epoch(&self) -> u64 {
        let new_epoch = self.epoch.fetch_add(1, Ordering::AcqRel) + 1;
        self.advance_count.fetch_add(1, Ordering::Relaxed);
        new_epoch
    }
}
impl std::fmt::Debug for GlobalEpochCounter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GlobalEpochCounter")
.field("epoch", &self.current_epoch())
.field("advance_count", &self.advance_count.load(Ordering::Relaxed))
.finish()
}
}
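/// Per-thread epoch participant: while a pin is active, the thread's
/// observed epoch is published for [`SafePointDetector`] to read.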
pub struct LocalEpochPin {
pinned_epoch: AtomicU64,
is_active: AtomicBool,
thread_id: u64,
global: Arc<GlobalEpochCounter>,
stats: LocalEpochStats,
}
impl LocalEpochPin {
    pub fn new(global: Arc<GlobalEpochCounter>) -> Self {
        // `std::thread::ThreadId` has no stable integer accessor, so a
        // monotonically increasing process-wide counter stands in.
        static NEXT_THREAD_ID: AtomicU64 = AtomicU64::new(1);
        Self {
            pinned_epoch: AtomicU64::new(0),
            is_active: AtomicBool::new(false),
            thread_id: NEXT_THREAD_ID.fetch_add(1, Ordering::Relaxed),
            global,
            stats: LocalEpochStats::default(),
        }
    }
    #[inline]
    pub fn pin(&self) -> EpochGuard<'_> {
        let epoch = self.global.current_epoch();
        self.pinned_epoch.store(epoch, Ordering::Release);
        self.is_active.store(true, Ordering::Release);
        self.stats.pin_count.fetch_add(1, Ordering::Relaxed);
        // Record the pin start in nanoseconds since the global counter's
        // creation; `max(1)` keeps 0 free as the "not pinned" sentinel.
        self.stats
            .pin_start
            .store(self.global.elapsed_ns().max(1), Ordering::Relaxed);
        EpochGuard { pin: self, epoch }
    }
#[inline]
pub fn pinned_epoch(&self) -> Option<u64> {
if self.is_active.load(Ordering::Acquire) {
Some(self.pinned_epoch.load(Ordering::Acquire))
} else {
None
}
}
pub fn thread_id(&self) -> u64 {
self.thread_id
}
pub fn stats(&self) -> LocalEpochStats {
LocalEpochStats {
pin_count: AtomicU64::new(self.stats.pin_count.load(Ordering::Relaxed)),
unpin_count: AtomicU64::new(self.stats.unpin_count.load(Ordering::Relaxed)),
max_pin_duration: AtomicU64::new(self.stats.max_pin_duration.load(Ordering::Relaxed)),
pin_start: AtomicU64::new(self.stats.pin_start.load(Ordering::Relaxed)),
}
}
    #[inline]
    fn unpin(&self) {
        self.is_active.store(false, Ordering::Release);
        self.stats.unpin_count.fetch_add(1, Ordering::Relaxed);
        // Fold this pin's duration into the running maximum, then clear the
        // start marker so a stale value is never reused.
        let start = self.stats.pin_start.swap(0, Ordering::Relaxed);
        if start > 0 {
            let duration = self.global.elapsed_ns().saturating_sub(start);
            let _ = self
                .stats
                .max_pin_duration
                .fetch_max(duration, Ordering::Relaxed);
        }
    }
}
impl std::fmt::Debug for LocalEpochPin {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LocalEpochPin")
.field("thread_id", &self.thread_id)
.field("pinned_epoch", &self.pinned_epoch())
.field("is_active", &self.is_active.load(Ordering::Relaxed))
.finish()
}
}
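/// RAII guard returned by [`LocalEpochPin::pin`]; unpins when dropped.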
pub struct EpochGuard<'a> {
pin: &'a LocalEpochPin,
epoch: u64,
}
impl EpochGuard<'_> {
#[must_use]
pub fn epoch(&self) -> u64 {
self.epoch
}
}
impl Drop for EpochGuard<'_> {
fn drop(&mut self) {
self.pin.unpin();
}
}
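/// Tracks registered pins and computes the minimum epoch any thread still
/// holds; cleanup deferred in an earlier epoch is safe to run.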
pub struct SafePointDetector {
thread_pins: Vec<CachePadded<Arc<LocalEpochPin>>>,
min_observed_epoch: AtomicU64,
last_safe_point: AtomicU64,
global: Arc<GlobalEpochCounter>,
stats: SafePointStats,
}
#[derive(Debug, Default)]
struct SafePointStats {
check_count: AtomicU64,
safe_point_count: AtomicU64,
check_time_ns: AtomicU64,
}
impl SafePointDetector {
pub fn new(global: Arc<GlobalEpochCounter>) -> Self {
Self {
thread_pins: Vec::new(),
min_observed_epoch: AtomicU64::new(1),
last_safe_point: AtomicU64::new(1),
global,
stats: SafePointStats::default(),
}
}
pub fn is_safe_point(&self, epoch: u64) -> bool {
let start = Instant::now();
let result = self.compute_safe_point() >= epoch;
self.stats.check_count.fetch_add(1, Ordering::Relaxed);
if result {
self.stats.safe_point_count.fetch_add(1, Ordering::Relaxed);
}
let elapsed = start.elapsed().as_nanos() as u64;
self.stats
.check_time_ns
.fetch_add(elapsed, Ordering::Relaxed);
result
}
#[cold]
fn compute_safe_point(&self) -> u64 {
let mut min_epoch = u64::MAX;
for pin in &self.thread_pins {
if let Some(pinned) = pin.pinned_epoch() {
min_epoch = min_epoch.min(pinned);
}
}
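        // No thread is currently pinned: everything before the current
        // epoch is reclaimable.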
if min_epoch == u64::MAX {
let current = self.global.current_epoch();
min_epoch = if current > 1 { current - 1 } else { 1 };
}
self.min_observed_epoch.store(min_epoch, Ordering::Release);
self.last_safe_point.store(min_epoch, Ordering::Release);
min_epoch
}
#[inline]
pub fn cached_safe_point(&self) -> u64 {
self.last_safe_point.load(Ordering::Acquire)
}
pub fn register_pin(&mut self, pin: Arc<LocalEpochPin>) {
self.thread_pins.push(CachePadded::new(pin));
}
pub fn stats(&self) -> (u64, u64, Duration, usize) {
let check_count = self.stats.check_count.load(Ordering::Relaxed);
let safe_point_count = self.stats.safe_point_count.load(Ordering::Relaxed);
let check_time = Duration::from_nanos(self.stats.check_time_ns.load(Ordering::Relaxed));
let active_pins = self
.thread_pins
.iter()
.map(|pin| usize::from(pin.is_active.load(Ordering::Relaxed)))
.sum();
(check_count, safe_point_count, check_time, active_pins)
}
}
impl std::fmt::Debug for SafePointDetector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SafePointDetector")
.field("thread_count", &self.thread_pins.len())
.field("last_safe_point", &self.cached_safe_point())
.finish()
}
}
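/// Categories of deferred cleanup work, with per-variant payloads used for
/// memory accounting.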
#[derive(Debug, Clone)]
pub enum CleanupWork {
Obligation { id: u64, metadata: Vec<u8> },
WakerCleanup { waker_id: u64, source: String },
RegionCleanup {
region_id: RegionId,
task_ids: Vec<TaskId>,
},
TimerCleanup { timer_id: u64, timer_type: String },
ChannelCleanup {
channel_id: u64,
cleanup_type: String,
data: Vec<u8>,
},
}
impl CleanupWork {
#[must_use]
pub fn memory_usage(&self) -> usize {
std::mem::size_of::<Self>()
+ match self {
Self::Obligation { metadata, .. } => metadata.len(),
Self::WakerCleanup { source, .. } => source.len(),
Self::RegionCleanup { task_ids, .. } => {
task_ids.len() * std::mem::size_of::<TaskId>()
}
Self::TimerCleanup { timer_type, .. } => timer_type.len(),
Self::ChannelCleanup {
cleanup_type, data, ..
} => cleanup_type.len() + data.len(),
}
}
}
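/// A deferred cleanup closure tagged with the epoch in which it was
/// enqueued.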
pub struct CleanupEntry {
pub epoch: u64,
pub cleanup_fn: Box<dyn FnOnce() + Send + 'static>,
pub debug_info: Option<String>,
pub memory_usage: usize,
}
impl std::fmt::Debug for CleanupEntry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CleanupEntry")
.field("epoch", &self.epoch)
.field("memory_usage", &self.memory_usage)
.field("debug_info", &self.debug_info)
.finish()
}
}
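/// Lock-free queue of deferred cleanup closures; an entry runs once the
/// safe epoch has strictly passed the epoch in which it was deferred.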
pub struct DeferredCleanupQueue {
queue: SegQueue<CleanupEntry>,
enqueue_count: AtomicU64,
execute_count: AtomicU64,
cleanup_time_ns: AtomicU64,
memory_usage: AtomicUsize,
global: Arc<GlobalEpochCounter>,
max_queue_size: AtomicUsize,
}
impl DeferredCleanupQueue {
pub fn new(global: Arc<GlobalEpochCounter>, max_size: usize) -> Self {
Self {
queue: SegQueue::new(),
enqueue_count: AtomicU64::new(0),
execute_count: AtomicU64::new(0),
cleanup_time_ns: AtomicU64::new(0),
memory_usage: AtomicUsize::new(0),
global,
max_queue_size: AtomicUsize::new(max_size),
}
}
pub fn defer_cleanup<F>(&self, cleanup_fn: F) -> Result<(), CleanupEntry>
where
F: FnOnce() + Send + 'static,
{
let epoch = self.global.current_epoch();
let entry = CleanupEntry {
epoch,
cleanup_fn: Box::new(cleanup_fn),
debug_info: None,
memory_usage: std::mem::size_of::<CleanupEntry>(),
};
let current_size = self.queue.len();
let max_size = self.max_queue_size.load(Ordering::Relaxed);
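        // The size check races with concurrent pushes, so the cap is
        // advisory rather than strict; `SegQueue` itself is unbounded.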
if current_size >= max_size {
return Err(entry);
}
self.queue.push(entry);
self.enqueue_count.fetch_add(1, Ordering::Relaxed);
self.memory_usage
.fetch_add(std::mem::size_of::<CleanupEntry>(), Ordering::Relaxed);
Ok(())
}
pub fn defer_cleanup_with_debug<F>(
&self,
cleanup_fn: F,
debug_info: String,
) -> Result<(), CleanupEntry>
where
F: FnOnce() + Send + 'static,
{
let epoch = self.global.current_epoch();
let memory_usage = std::mem::size_of::<CleanupEntry>() + debug_info.len();
let entry = CleanupEntry {
epoch,
cleanup_fn: Box::new(cleanup_fn),
debug_info: Some(debug_info),
memory_usage,
};
let current_size = self.queue.len();
let max_size = self.max_queue_size.load(Ordering::Relaxed);
if current_size >= max_size {
return Err(entry);
}
self.queue.push(entry);
self.enqueue_count.fetch_add(1, Ordering::Relaxed);
self.memory_usage.fetch_add(memory_usage, Ordering::Relaxed);
Ok(())
}
pub fn execute_safe_cleanups(&self, safe_epoch: u64) -> usize {
let start = Instant::now();
let mut executed = 0;
let mut entries_to_requeue = Vec::new();
        while let Some(entry) = self.queue.pop() {
            // Strictly less-than: threads pinned at epoch E may still hold
            // references created in E, so an entry only runs once the safe
            // epoch has moved past its defer epoch.
            if entry.epoch < safe_epoch {
                (entry.cleanup_fn)();
executed += 1;
self.memory_usage
.fetch_sub(entry.memory_usage, Ordering::Relaxed);
} else {
entries_to_requeue.push(entry);
}
}
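        // Entries that are not yet safe go back on the queue. Re-pushing
        // loses FIFO order relative to concurrent enqueues, which is
        // acceptable for cleanup work.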
for entry in entries_to_requeue {
self.queue.push(entry);
}
if executed > 0 {
self.execute_count
.fetch_add(executed as u64, Ordering::Relaxed);
let elapsed = start.elapsed().as_nanos() as u64;
self.cleanup_time_ns.fetch_add(elapsed, Ordering::Relaxed);
}
executed
}
pub fn stats(&self) -> CleanupQueueStats {
CleanupQueueStats {
pending_count: self.queue.len(),
enqueue_count: self.enqueue_count.load(Ordering::Relaxed),
execute_count: self.execute_count.load(Ordering::Relaxed),
memory_usage: self.memory_usage.load(Ordering::Relaxed),
total_cleanup_time: Duration::from_nanos(self.cleanup_time_ns.load(Ordering::Relaxed)),
}
}
pub fn is_near_capacity(&self) -> bool {
let current = self.queue.len();
let max = self.max_queue_size.load(Ordering::Relaxed);
        current >= (max * 8) / 10
    }
pub fn set_max_size(&self, max_size: usize) {
self.max_queue_size.store(max_size, Ordering::Relaxed);
}
}
impl std::fmt::Debug for DeferredCleanupQueue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("DeferredCleanupQueue")
.field("pending_count", &self.queue.len())
.field("memory_usage", &self.memory_usage.load(Ordering::Relaxed))
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;
use std::time::Duration;
#[test]
fn test_epoch_advancement() {
let config = EpochConfig {
advance_interval: Duration::from_millis(1),
..EpochConfig::default()
};
let epoch = Arc::new(GlobalEpochCounter::new(config));
let initial = epoch.current_epoch();
assert_eq!(epoch.try_advance(), None);
thread::sleep(Duration::from_millis(2));
let advanced = epoch.try_advance();
assert!(advanced.is_some());
assert_eq!(advanced.unwrap(), initial + 1);
let new_epoch = epoch.force_advance();
assert_eq!(new_epoch, initial + 2);
}
#[test]
fn test_epoch_pin() {
let epoch = Arc::new(GlobalEpochCounter::new(EpochConfig::default()));
let pin = LocalEpochPin::new(epoch.clone());
assert_eq!(pin.pinned_epoch(), None);
let guard = pin.pin();
let pinned_epoch = guard.epoch();
assert_eq!(pin.pinned_epoch(), Some(pinned_epoch));
epoch.force_advance();
assert!(epoch.current_epoch() > pinned_epoch);
assert_eq!(pin.pinned_epoch(), Some(pinned_epoch));
drop(guard);
assert_eq!(pin.pinned_epoch(), None);
}
#[test]
fn test_safe_point_detection() {
let epoch = Arc::new(GlobalEpochCounter::new(EpochConfig::default()));
let mut detector = SafePointDetector::new(epoch.clone());
let pin = Arc::new(LocalEpochPin::new(epoch.clone()));
detector.register_pin(pin.clone());
let initial_epoch = epoch.current_epoch();
let safe_point = detector.compute_safe_point();
assert!(safe_point <= initial_epoch);
let _guard = pin.pin();
let pinned_epoch = pin.pinned_epoch().unwrap();
epoch.force_advance();
epoch.force_advance();
let safe_point = detector.compute_safe_point();
assert!(safe_point <= pinned_epoch);
drop(_guard);
let safe_point_after = detector.compute_safe_point();
assert!(safe_point_after > safe_point);
}
#[test]
fn test_cleanup_queue() {
let epoch = Arc::new(GlobalEpochCounter::new(EpochConfig::default()));
let queue = DeferredCleanupQueue::new(epoch.clone(), 1000);
let executed = Arc::new(AtomicBool::new(false));
let executed_clone = executed.clone();
queue
.defer_cleanup(move || {
executed_clone.store(true, Ordering::Relaxed);
})
.unwrap();
let count = queue.execute_safe_cleanups(epoch.current_epoch());
assert_eq!(count, 0);
assert!(!executed.load(Ordering::Relaxed));
epoch.force_advance();
let count = queue.execute_safe_cleanups(epoch.current_epoch());
assert_eq!(count, 1);
assert!(executed.load(Ordering::Relaxed));
}
#[test]
fn test_cleanup_ordering() {
let epoch = Arc::new(GlobalEpochCounter::new(EpochConfig::default()));
let queue = DeferredCleanupQueue::new(epoch.clone(), 1000);
let executed_order = Arc::new(std::sync::Mutex::new(Vec::new()));
for i in 0..5 {
let order = executed_order.clone();
queue
.defer_cleanup(move || {
order.lock().unwrap().push(i);
})
.unwrap();
if i % 2 == 0 {
epoch.force_advance();
}
}
let safe_epoch = epoch.force_advance();
let count = queue.execute_safe_cleanups(safe_epoch);
assert_eq!(count, 5);
let order = executed_order.lock().unwrap();
        assert_eq!(order.len(), 5);
}
#[test]
fn test_queue_backpressure() {
let epoch = Arc::new(GlobalEpochCounter::new(EpochConfig::default()));
let queue = DeferredCleanupQueue::new(epoch.clone(), 2);
let result1 = queue.defer_cleanup(|| {});
assert!(result1.is_ok());
let result2 = queue.defer_cleanup(|| {});
assert!(result2.is_ok());
let result3 = queue.defer_cleanup(|| {});
assert!(result3.is_err());
epoch.force_advance();
let executed = queue.execute_safe_cleanups(epoch.current_epoch());
assert!(executed > 0);
let result4 = queue.defer_cleanup(|| {});
assert!(result4.is_ok());
}
#[test]
fn test_concurrent_operations() {
let epoch = Arc::new(GlobalEpochCounter::new(EpochConfig {
advance_interval: Duration::from_millis(1),
..EpochConfig::default()
}));
let pin = Arc::new(LocalEpochPin::new(epoch.clone()));
let queue = Arc::new(DeferredCleanupQueue::new(epoch.clone(), 10000));
let mut handles = Vec::new();
for i in 0..4 {
let epoch = epoch.clone();
let pin = pin.clone();
let queue = queue.clone();
let handle = thread::spawn(move || {
for j in 0..100 {
let _guard = pin.pin();
let value = i * 100 + j;
let _ = queue.defer_cleanup(move || {
std::hint::black_box(value);
});
if j % 10 == 0 {
epoch.try_advance();
}
thread::yield_now();
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
let final_epoch = epoch.force_advance();
let executed = queue.execute_safe_cleanups(final_epoch);
assert!(executed > 0);
        assert!(executed <= 400);
    }
}