use std::cell::RefCell;
use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use crossbeam_channel::{self, Receiver, Sender};
use parking_lot::Mutex;
use crate::txn_arena::KeyFingerprint;
/// Number of key fingerprints a per-thread buffer reserves room for up front.
const DEFAULT_BATCH_SIZE: usize = 64;
/// Longest time, in milliseconds, the aggregator blocks on the channel before
/// it re-checks its running flag.
const MAX_FLUSH_INTERVAL_MS: u64 = 10;
/// Messages carried over the tracker's channel from producers to the
/// aggregator thread.
#[derive(Debug)]
pub enum DirtyEvent {
/// A group of key fingerprints reported together.
Batch {
/// Transaction id the keys were recorded under. The aggregator currently
/// ignores this value when it files the keys.
txn_id: u64,
/// Fingerprints of the keys that were marked dirty.
keys: Vec<KeyFingerprint>,
},
/// Posted by `BatchedDirtyTracker::advance_epoch`; the aggregator takes no
/// action when it receives one.
AdvanceEpoch,
/// Posted by `BatchedDirtyTracker::stop`; the aggregator's final drain pass
/// stops at the first one it sees.
Shutdown,
}
/// Per-thread staging area that groups dirty keys by transaction before they
/// are published on the aggregator channel.
struct ThreadLocalBuffer {
    /// Transaction id the currently buffered keys belong to.
    txn_id: u64,
    /// Keys recorded since the last flush.
    keys: Vec<KeyFingerprint>,
    /// Channel endpoint that batches are published on.
    sender: Sender<DirtyEvent>,
}

impl ThreadLocalBuffer {
    /// Creates an empty buffer that publishes on `sender`.
    fn new(sender: Sender<DirtyEvent>) -> Self {
        Self {
            txn_id: 0,
            keys: Vec::with_capacity(DEFAULT_BATCH_SIZE),
            sender,
        }
    }

    /// Records `key_fingerprint` as dirty for `txn_id`.
    ///
    /// Keys of a previous transaction are flushed before the first key of a
    /// new transaction is stored, so a batch never mixes transaction ids. The
    /// buffer is also flushed as soon as it holds `DEFAULT_BATCH_SIZE` keys,
    /// which bounds its memory use and keeps long transactions visible to the
    /// aggregator.
    #[inline]
    fn mark_dirty(&mut self, txn_id: u64, key_fingerprint: KeyFingerprint) {
        if self.txn_id != txn_id {
            self.flush();
            self.txn_id = txn_id;
        }
        self.keys.push(key_fingerprint);
        if self.keys.len() >= DEFAULT_BATCH_SIZE {
            self.flush();
        }
    }

    /// Publishes the buffered keys as one `DirtyEvent::Batch` and leaves the
    /// buffer empty with its initial capacity. Does nothing when no keys are
    /// buffered.
    fn flush(&mut self) {
        if self.keys.is_empty() {
            return;
        }
        let keys = std::mem::replace(&mut self.keys, Vec::with_capacity(DEFAULT_BATCH_SIZE));
        // Deliberately non-blocking: if the send fails the batch is dropped
        // rather than stalling the calling thread.
        let _ = self.sender.try_send(DirtyEvent::Batch {
            txn_id: self.txn_id,
            keys,
        });
    }
}
/// Collects dirty-key batches from producers over a channel and files them
/// into a small ring of per-epoch key sets on a background aggregator thread.
pub struct BatchedDirtyTracker {
// Producer endpoint; cloned for `get_sender` and for per-thread buffers.
sender: Sender<DirtyEvent>,
// Consumer endpoint read by the aggregator loop.
receiver: Receiver<DirtyEvent>,
// Join handle of the aggregator thread; set by `start`, taken by `stop`.
aggregator_handle: Mutex<Option<JoinHandle<()>>>,
// True between a successful `start` and the next `stop`.
running: AtomicBool,
// Epoch that incoming batches are currently filed under.
current_epoch: AtomicU64,
// Ring of key sets; an epoch's slot is `epoch % EPOCH_RING_SIZE`.
epochs: [Mutex<HashSet<KeyFingerprint>>; 4],
// Counters exposed through `stats()`.
stats: DirtyTrackingStats,
}
/// Counters maintained by the tracker. Every field is an atomic so readers
/// can sample it without taking a lock.
///
/// `Default` yields all counters at zero.
#[derive(Debug, Default)]
pub struct DirtyTrackingStats {
    /// Incremented once for every batch event the aggregator processes.
    pub events_received: AtomicU64,
    /// Sum of the lengths of all processed batches; a key reported more than
    /// once is counted every time.
    pub keys_tracked: AtomicU64,
    /// Number of batches the aggregator has processed.
    pub batches_received: AtomicU64,
    /// Epoch number stored by the most recent `advance_epoch` call.
    pub current_epoch: AtomicU64,
}
/// Number of per-epoch key sets kept in the tracker's ring; an epoch number is
/// mapped to its slot by taking it modulo this value.
const EPOCH_RING_SIZE: usize = 4;
impl BatchedDirtyTracker {
pub fn new() -> Arc<Self> {
let (sender, receiver) = crossbeam_channel::bounded(1024);
let tracker = Arc::new(Self {
sender,
receiver,
aggregator_handle: Mutex::new(None),
running: AtomicBool::new(false),
current_epoch: AtomicU64::new(0),
epochs: [
Mutex::new(HashSet::new()),
Mutex::new(HashSet::new()),
Mutex::new(HashSet::new()),
Mutex::new(HashSet::new()),
],
stats: DirtyTrackingStats::default(),
});
tracker
}
pub fn start(self: &Arc<Self>) {
if self.running.swap(true, Ordering::SeqCst) {
return; }
let tracker = Arc::clone(self);
let handle = thread::spawn(move || {
tracker.aggregator_loop();
});
*self.aggregator_handle.lock() = Some(handle);
}
pub fn stop(&self) {
if !self.running.swap(false, Ordering::SeqCst) {
return; }
let _ = self.sender.send(DirtyEvent::Shutdown);
if let Some(handle) = self.aggregator_handle.lock().take() {
let _ = handle.join();
}
}
pub fn get_sender(&self) -> Sender<DirtyEvent> {
self.sender.clone()
}
#[inline]
pub fn mark_dirty(&self, txn_id: u64, key_fingerprint: KeyFingerprint) {
thread_local! {
static BUFFER: RefCell<Option<ThreadLocalBuffer>> = const { RefCell::new(None) };
}
BUFFER.with(|cell| {
let mut buffer = cell.borrow_mut();
if buffer.is_none() {
*buffer = Some(ThreadLocalBuffer::new(self.sender.clone()));
}
buffer.as_mut().unwrap().mark_dirty(txn_id, key_fingerprint);
});
}
pub fn flush_thread_buffer(&self) {
thread_local! {
static BUFFER: RefCell<Option<ThreadLocalBuffer>> = const { RefCell::new(None) };
}
BUFFER.with(|cell| {
if let Some(buffer) = cell.borrow_mut().as_mut() {
buffer.flush();
}
});
}
#[inline]
pub fn send_batch(&self, txn_id: u64, keys: Vec<KeyFingerprint>) {
if keys.is_empty() {
return;
}
let _ = self.sender.try_send(DirtyEvent::Batch { txn_id, keys });
}
pub fn advance_epoch(&self) -> (u64, Vec<KeyFingerprint>) {
let _ = self.sender.try_send(DirtyEvent::AdvanceEpoch);
let old_epoch = self.current_epoch.fetch_add(1, Ordering::SeqCst);
let old_idx = (old_epoch as usize) % EPOCH_RING_SIZE;
let mut guard = self.epochs[old_idx].lock();
let keys: Vec<_> = guard.drain().collect();
self.stats.current_epoch.store(old_epoch + 1, Ordering::Relaxed);
(old_epoch, keys)
}
pub fn current_epoch(&self) -> u64 {
self.current_epoch.load(Ordering::Relaxed)
}
pub fn stats(&self) -> &DirtyTrackingStats {
&self.stats
}
fn aggregator_loop(&self) {
use crossbeam_channel::RecvTimeoutError;
let timeout = std::time::Duration::from_millis(MAX_FLUSH_INTERVAL_MS);
while self.running.load(Ordering::Relaxed) {
match self.receiver.recv_timeout(timeout) {
Ok(event) => {
self.process_event(event);
}
Err(RecvTimeoutError::Timeout) => {
}
Err(RecvTimeoutError::Disconnected) => {
break;
}
}
}
while let Ok(event) = self.receiver.try_recv() {
if matches!(event, DirtyEvent::Shutdown) {
break;
}
self.process_event(event);
}
}
fn process_event(&self, event: DirtyEvent) {
match event {
DirtyEvent::Batch { txn_id: _, keys } => {
let epoch = self.current_epoch.load(Ordering::Relaxed);
let idx = (epoch as usize) % EPOCH_RING_SIZE;
let mut guard = self.epochs[idx].lock();
let key_count = keys.len();
guard.extend(keys);
self.stats.events_received.fetch_add(1, Ordering::Relaxed);
self.stats.keys_tracked.fetch_add(key_count as u64, Ordering::Relaxed);
self.stats.batches_received.fetch_add(1, Ordering::Relaxed);
}
DirtyEvent::AdvanceEpoch => {
}
DirtyEvent::Shutdown => {
}
}
}
}
impl Default for BatchedDirtyTracker {
    /// Builds an idle tracker: a bounded channel with room for 1024 events,
    /// four empty key sets, epoch 0, zeroed counters and no aggregator
    /// thread.
    fn default() -> Self {
        let fresh_set = || Mutex::new(HashSet::new());
        let (tx, rx) = crossbeam_channel::bounded(1024);
        Self {
            stats: DirtyTrackingStats::default(),
            epochs: [fresh_set(), fresh_set(), fresh_set(), fresh_set()],
            current_epoch: AtomicU64::new(0),
            running: AtomicBool::new(false),
            aggregator_handle: Mutex::new(None),
            receiver: rx,
            sender: tx,
        }
    }
}
impl Drop for BatchedDirtyTracker {
fn drop(&mut self) {
self.stop();
}
}
/// Caller-owned buffer that collects the dirty keys of one transaction so
/// they can be handed to a tracker in a single batch.
pub struct TxnDirtyBuffer {
    /// Transaction id attached to every batch this buffer sends.
    txn_id: u64,
    /// Keys recorded since the buffer was last emptied.
    keys: Vec<KeyFingerprint>,
}

impl TxnDirtyBuffer {
    /// Capacity reserved by `new` and restored after `flush_to`.
    const INITIAL_CAPACITY: usize = 64;

    /// Creates an empty buffer for `txn_id` with the default capacity.
    #[inline]
    pub fn new(txn_id: u64) -> Self {
        Self::with_capacity(txn_id, Self::INITIAL_CAPACITY)
    }

    /// Creates an empty buffer for `txn_id` with room for `capacity` keys.
    #[inline]
    pub fn with_capacity(txn_id: u64, capacity: usize) -> Self {
        let keys = Vec::with_capacity(capacity);
        Self { txn_id, keys }
    }

    /// Appends one key fingerprint.
    #[inline]
    pub fn record(&mut self, key_fingerprint: KeyFingerprint) {
        self.keys.push(key_fingerprint);
    }

    /// Appends every fingerprint yielded by `key_fingerprints`, in order.
    #[inline]
    pub fn record_many(&mut self, key_fingerprints: impl IntoIterator<Item = KeyFingerprint>) {
        Extend::extend(&mut self.keys, key_fingerprints);
    }

    /// Returns the transaction id this buffer was created for.
    #[inline]
    pub fn txn_id(&self) -> u64 {
        self.txn_id
    }

    /// Returns how many fingerprints are buffered (duplicates included).
    #[inline]
    pub fn len(&self) -> usize {
        self.keys.len()
    }

    /// Returns `true` when nothing is buffered.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.keys.len() == 0
    }

    /// Moves all buffered fingerprints out, leaving the buffer empty.
    #[inline]
    pub fn drain(&mut self) -> Vec<KeyFingerprint> {
        std::mem::take(&mut self.keys)
    }

    /// Sends the buffered fingerprints to `tracker` as one batch and resets
    /// the buffer to an empty one with the default capacity. Does nothing
    /// when the buffer is empty.
    #[inline]
    pub fn flush_to(&mut self, tracker: &BatchedDirtyTracker) {
        if self.keys.is_empty() {
            return;
        }
        let pending = std::mem::replace(
            &mut self.keys,
            Vec::with_capacity(Self::INITIAL_CAPACITY),
        );
        tracker.send_batch(self.txn_id, pending);
    }

    /// Discards the buffered fingerprints without sending them.
    #[inline]
    pub fn clear(&mut self) {
        self.keys.truncate(0);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Keys recorded in a `TxnDirtyBuffer` come back out of `drain`, which
    /// leaves the buffer empty.
    #[test]
    fn test_txn_dirty_buffer() {
        let mut buffer = TxnDirtyBuffer::new(1);
        buffer.record(KeyFingerprint::from_bytes(b"key1"));
        buffer.record(KeyFingerprint::from_bytes(b"key2"));
        buffer.record(KeyFingerprint::from_bytes(b"key3"));
        assert_eq!(buffer.len(), 3);
        let keys = buffer.drain();
        assert_eq!(keys.len(), 3);
        assert!(buffer.is_empty());
    }

    /// A batch sent to a running tracker is filed under the current epoch
    /// and handed back by `advance_epoch`.
    #[test]
    fn test_batched_tracker_basic() {
        let tracker = BatchedDirtyTracker::new();
        tracker.start();
        tracker.send_batch(1, vec![
            KeyFingerprint::from_bytes(b"key1"),
            KeyFingerprint::from_bytes(b"key2"),
        ]);
        // `stop` joins the aggregator, which drains the queue before it
        // exits, so the batch is guaranteed to be processed without a sleep.
        tracker.stop();
        let (epoch, keys) = tracker.advance_epoch();
        assert_eq!(epoch, 0);
        assert_eq!(keys.len(), 2);
        assert!(tracker.stats().batches_received.load(Ordering::Relaxed) >= 1);
    }

    /// `advance_epoch` returns the keys of the epoch it closes and moves on
    /// to an empty slot.
    #[test]
    fn test_epoch_rotation() {
        let tracker = BatchedDirtyTracker::new();
        {
            let mut guard = tracker.epochs[0].lock();
            guard.insert(KeyFingerprint::from_bytes(b"key1"));
            guard.insert(KeyFingerprint::from_bytes(b"key2"));
        }
        let (epoch, keys) = tracker.advance_epoch();
        assert_eq!(epoch, 0);
        assert_eq!(keys.len(), 2);
        let (epoch2, keys2) = tracker.advance_epoch();
        assert_eq!(epoch2, 1);
        assert!(keys2.is_empty());
    }
}