use std::cell::RefCell;
use std::sync::atomic::{AtomicUsize, Ordering};
thread_local! {
    // Batch nesting level for the current thread: incremented by `start_batch`,
    // decremented (saturating) by `end_batch`; `is_batching` means "non-zero".
    static BATCH_DEPTH: RefCell<usize> = const { RefCell::new(0) };
    // Updates deferred by `queue_update` while this thread is batching. Drained and
    // run by `flush_updates`, which `end_batch` calls when the depth returns to
    // zero and `flush` calls on demand.
    static PENDING_UPDATES: RefCell<Vec<Box<dyn FnOnce()>>> = const { RefCell::new(Vec::new()) };
}
// Number of `start_batch` calls. Unlike the thread-locals above, this is a plain
// static, so it is shared by all threads; it is only read through `batch_count`.
static BATCH_COUNTER: AtomicUsize = AtomicUsize::new(0);
pub fn batch<F, R>(f: F) -> R
where
F: FnOnce() -> R,
{
start_batch();
let result = f();
end_batch();
result
}
/// Opens one level of batching on the current thread.
///
/// Pair every call with [`end_batch`]. Until the depth returns to zero,
/// [`queue_update`] defers updates instead of running them. Also bumps the
/// counter reported by [`batch_count`].
pub fn start_batch() {
    BATCH_DEPTH.with(|cell| {
        let mut level = cell.borrow_mut();
        *level += 1;
    });
    BATCH_COUNTER.fetch_add(1, Ordering::Relaxed);
}
/// Closes one level of batching on the current thread.
///
/// Decrements the batch depth, saturating at zero. When the depth is zero after
/// the decrement, every queued update is run. This includes an unmatched call
/// made while not batching: it leaves the depth at zero and still flushes.
pub fn end_batch() {
    // Release the `RefCell` borrow *before* flushing. Queued updates may call
    // `is_batching`, `batch_depth`, `queue_update`, `batch`, ..., which borrow
    // `BATCH_DEPTH` again and would panic while a `RefMut` is still alive.
    let remaining = BATCH_DEPTH.with(|depth| {
        let mut d = depth.borrow_mut();
        *d = d.saturating_sub(1);
        *d
    });
    if remaining == 0 {
        flush_updates();
    }
}
/// Returns `true` while at least one batch is open on the current thread.
pub fn is_batching() -> bool {
    batch_depth() != 0
}
/// Returns the current thread's batch nesting level (`0` when not batching).
pub fn batch_depth() -> usize {
    BATCH_DEPTH.with(|cell| {
        let level = cell.borrow();
        *level
    })
}
pub fn batch_count() -> usize {
BATCH_COUNTER.load(Ordering::Relaxed)
}
pub fn flush() {
flush_updates();
}
/// Runs `f` immediately, or defers it if this thread is batching.
///
/// Outside a batch, `f` executes right away. Inside one, it is appended to this
/// thread's pending queue and runs at the next flush, either when the outermost
/// batch ends or when [`flush`] is called.
pub fn queue_update<F: FnOnce() + 'static>(f: F) {
    if !is_batching() {
        f();
        return;
    }
    PENDING_UPDATES.with(|queue| queue.borrow_mut().push(Box::new(f)));
}
/// Returns how many deferred updates are waiting on the current thread.
pub fn pending_count() -> usize {
    PENDING_UPDATES.with(|queue| {
        let queued = queue.borrow();
        queued.len()
    })
}
/// Drains this thread's pending queue and runs each update in insertion order.
///
/// The queue is emptied into a local `Vec`, and its `RefCell` borrow is released,
/// before any update runs. Updates can therefore push onto `PENDING_UPDATES`
/// again without a borrow conflict. Anything pushed during the flush stays
/// queued for the next flush.
fn flush_updates() {
    PENDING_UPDATES.with(|queue| {
        let drained: Vec<Box<dyn FnOnce()>> = queue.borrow_mut().drain(..).collect();
        drained.into_iter().for_each(|update| update());
    });
}
/// A list of deferred updates that are either run together or discarded.
///
/// Updates are recorded with [`Transaction::update`]. [`Transaction::commit`]
/// runs them in insertion order inside a single [`batch`], and
/// [`Transaction::rollback`] drops them unrun. Dropping a non-empty, uncommitted
/// transaction also discards its updates; debug builds print a warning to
/// stderr when that happens.
#[must_use = "a Transaction does nothing unless `commit` is called"]
pub struct Transaction {
    updates: Vec<Box<dyn FnOnce()>>,
    committed: bool,
}

impl Transaction {
    /// Creates an empty, uncommitted transaction.
    pub fn new() -> Self {
        Self {
            updates: Vec::new(),
            committed: false,
        }
    }

    /// Records `f` to be run when the transaction is committed.
    pub fn update<F: FnOnce() + 'static>(&mut self, f: F) {
        self.updates.push(Box::new(f));
    }

    /// Runs every recorded update, in insertion order, inside one batch.
    ///
    /// Because the updates run inside a batch, anything they enqueue through
    /// [`queue_update`] is deferred until the outermost batch on this thread ends.
    pub fn commit(mut self) {
        // Set up front: even if an update panics, `Drop` will not report the
        // transaction as dropped without commit.
        self.committed = true;
        batch(|| {
            for update in self.updates.drain(..) {
                update();
            }
        });
    }

    /// Discards every recorded update without running it.
    pub fn rollback(mut self) {
        // Emptying the list before the implicit drop keeps `Drop` from
        // printing the "dropped without commit" warning.
        self.updates.clear();
    }

    /// Returns `true` if no updates are recorded.
    pub fn is_empty(&self) -> bool {
        self.updates.is_empty()
    }

    /// Returns the number of recorded updates.
    pub fn len(&self) -> usize {
        self.updates.len()
    }
}

impl Default for Transaction {
    /// Equivalent to [`Transaction::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Debug for Transaction {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The boxed closures are opaque, so report only how many are pending.
        f.debug_struct("Transaction")
            .field("pending", &self.updates.len())
            .field("committed", &self.committed)
            .finish()
    }
}

impl Drop for Transaction {
    fn drop(&mut self) {
        // The updates themselves are dropped unrun with the Vec; this only reports it.
        if !self.committed && !self.updates.is_empty() {
            #[cfg(debug_assertions)]
            eprintln!(
                "Warning: Transaction dropped without commit ({} updates discarded)",
                self.updates.len()
            );
        }
    }
}
/// RAII guard that keeps a batch open on the current thread while it is alive.
///
/// [`BatchGuard::new`] calls [`start_batch`], and dropping the guard calls
/// [`end_batch`]. Queued updates are therefore flushed when the outermost batch
/// on this thread closes, including when the guard is dropped during unwinding.
///
/// NOTE(review): the batch depth is thread-local, but this guard is `Send`.
/// Dropping it on another thread would close a batch on *that* thread and
/// leave this one batching. Consider making it `!Send`
/// (e.g. `PhantomData<*const ()>`) in a breaking release.
#[must_use = "the batch ends as soon as the guard is dropped; bind it with `let _guard = ...`"]
#[derive(Debug)]
pub struct BatchGuard {
    _private: (),
}

impl BatchGuard {
    /// Opens a batch on the current thread and returns the guard that closes it.
    pub fn new() -> Self {
        start_batch();
        Self { _private: () }
    }
}

impl Default for BatchGuard {
    /// Equivalent to [`BatchGuard::new`]: opens a batch.
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for BatchGuard {
    fn drop(&mut self) {
        end_batch();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_batch_basic() {
        let mut executed = false;
        batch(|| {
            executed = true;
        });
        assert!(executed);
    }

    #[test]
    fn test_batch_return_value() {
        let result = batch(|| 42);
        assert_eq!(result, 42);
    }

    #[test]
    fn test_batch_nested() {
        batch(|| {
            let outer_depth = batch_depth();
            assert!(outer_depth >= 1);
            batch(|| {
                assert!(batch_depth() > outer_depth);
            });
        });
    }

    #[test]
    fn test_start_end_batch() {
        assert_eq!(batch_depth(), 0);
        assert!(!is_batching());
        start_batch();
        assert_eq!(batch_depth(), 1);
        assert!(is_batching());
        end_batch();
        assert_eq!(batch_depth(), 0);
        assert!(!is_batching());
    }

    #[test]
    fn test_nested_start_end_batch() {
        start_batch();
        assert_eq!(batch_depth(), 1);
        start_batch();
        assert_eq!(batch_depth(), 2);
        end_batch();
        assert_eq!(batch_depth(), 1);
        end_batch();
        assert_eq!(batch_depth(), 0);
    }

    #[test]
    fn test_batch_depth_initial() {
        assert_eq!(batch_depth(), 0);
    }

    #[test]
    fn test_batch_depth_single_batch() {
        batch(|| {
            assert_eq!(batch_depth(), 1);
        });
        assert_eq!(batch_depth(), 0);
    }

    #[test]
    fn test_batch_depth_nested_batches() {
        batch(|| {
            assert_eq!(batch_depth(), 1);
            batch(|| {
                assert_eq!(batch_depth(), 2);
                batch(|| {
                    assert_eq!(batch_depth(), 3);
                });
                assert_eq!(batch_depth(), 2);
            });
            assert_eq!(batch_depth(), 1);
        });
        assert_eq!(batch_depth(), 0);
    }

    #[test]
    fn test_is_batching_false_initially() {
        assert!(!is_batching());
    }

    #[test]
    fn test_is_batching_true_in_batch() {
        batch(|| {
            assert!(is_batching());
        });
        assert!(!is_batching());
    }

    #[test]
    fn test_batch_count_increments() {
        // `BATCH_COUNTER` is process-wide, so other tests may also bump it;
        // only a strict increase can be asserted.
        let count_before = batch_count();
        batch(|| {});
        assert!(batch_count() > count_before);
    }

    #[test]
    fn test_flush_does_not_panic() {
        flush();
        start_batch();
        flush();
        end_batch();
    }

    #[test]
    fn test_queue_update_outside_batch() {
        let executed = Arc::new(AtomicBool::new(false));
        let executed_clone = executed.clone();
        queue_update(move || {
            executed_clone.store(true, Ordering::SeqCst);
        });
        assert!(
            executed.load(Ordering::SeqCst),
            "Should execute immediately when not batching"
        );
    }

    #[test]
    fn test_queue_update_inside_batch() {
        let executed = Arc::new(AtomicBool::new(false));
        let executed_clone = executed.clone();
        batch(|| {
            queue_update(move || {
                executed_clone.store(true, Ordering::SeqCst);
            });
            assert!(!executed.load(Ordering::SeqCst));
        });
        assert!(executed.load(Ordering::SeqCst));
    }

    #[test]
    fn test_nested_batch_defers_until_outermost() {
        let executed = Arc::new(AtomicBool::new(false));
        let executed_clone = executed.clone();
        batch(|| {
            batch(|| {
                queue_update(move || {
                    executed_clone.store(true, Ordering::SeqCst);
                });
            });
            assert!(
                !executed.load(Ordering::SeqCst),
                "inner batch must not flush while the outer batch is open"
            );
            assert_eq!(pending_count(), 1);
        });
        assert!(executed.load(Ordering::SeqCst));
        assert_eq!(pending_count(), 0);
    }

    #[test]
    fn test_pending_count_outside_batch() {
        assert_eq!(pending_count(), 0);
    }

    #[test]
    fn test_pending_count_inside_batch() {
        batch(|| {
            queue_update(|| {});
            assert_eq!(pending_count(), 1);
            queue_update(|| {});
            assert_eq!(pending_count(), 2);
            flush();
            assert_eq!(pending_count(), 0);
        });
    }

    #[test]
    fn test_transaction_new() {
        let tx = Transaction::new();
        assert!(tx.is_empty());
        assert_eq!(tx.len(), 0);
    }

    #[test]
    fn test_transaction_default() {
        let tx = Transaction::default();
        assert!(tx.is_empty());
    }

    #[test]
    fn test_transaction_update() {
        let mut tx = Transaction::new();
        assert_eq!(tx.len(), 0);
        tx.update(|| {});
        assert_eq!(tx.len(), 1);
        tx.update(|| {});
        assert_eq!(tx.len(), 2);
        assert!(!tx.is_empty());
    }

    #[test]
    fn test_transaction_commit() {
        let executed = Arc::new(AtomicBool::new(false));
        let executed_clone = executed.clone();
        let mut tx = Transaction::new();
        tx.update(move || {
            executed_clone.store(true, Ordering::SeqCst);
        });
        assert!(!executed.load(Ordering::SeqCst));
        tx.commit();
        assert!(executed.load(Ordering::SeqCst));
    }

    #[test]
    fn test_transaction_rollback() {
        let executed = Arc::new(AtomicBool::new(false));
        let executed_clone = executed.clone();
        let mut tx = Transaction::new();
        tx.update(move || {
            executed_clone.store(true, Ordering::SeqCst);
        });
        tx.rollback();
        assert!(
            !executed.load(Ordering::SeqCst),
            "Updates should be discarded"
        );
    }

    #[test]
    fn test_transaction_len() {
        let mut tx = Transaction::new();
        assert_eq!(tx.len(), 0);
        tx.update(|| {});
        tx.update(|| {});
        tx.update(|| {});
        assert_eq!(tx.len(), 3);
    }

    #[test]
    fn test_transaction_is_empty() {
        let mut tx = Transaction::new();
        assert!(tx.is_empty());
        tx.update(|| {});
        assert!(!tx.is_empty());
    }

    #[test]
    fn test_transaction_commit_empties() {
        use std::sync::Mutex;
        let order = Arc::new(Mutex::new(Vec::new()));
        let mut tx = Transaction::new();
        for i in 1..=2 {
            let order = Arc::clone(&order);
            tx.update(move || order.lock().unwrap().push(i));
        }
        assert_eq!(tx.len(), 2);
        tx.commit();
        // `commit` consumes the transaction, so verify that every recorded
        // update was run, in insertion order.
        assert_eq!(*order.lock().unwrap(), vec![1, 2]);
    }

    #[test]
    fn test_batch_guard_new() {
        assert_eq!(batch_depth(), 0);
        {
            let _guard = BatchGuard::new();
            assert_eq!(batch_depth(), 1);
            assert!(is_batching());
        }
        assert_eq!(batch_depth(), 0);
    }

    #[test]
    fn test_batch_guard_default() {
        assert_eq!(batch_depth(), 0);
        {
            let _guard = BatchGuard::default();
            assert_eq!(batch_depth(), 1);
        }
        assert_eq!(batch_depth(), 0);
    }

    #[test]
    fn test_batch_guard_nested() {
        assert_eq!(batch_depth(), 0);
        {
            let _guard1 = BatchGuard::new();
            assert_eq!(batch_depth(), 1);
            {
                let _guard2 = BatchGuard::new();
                assert_eq!(batch_depth(), 2);
            }
            assert_eq!(batch_depth(), 1);
        }
        assert_eq!(batch_depth(), 0);
    }

    #[test]
    fn test_batch_guard_flushes_on_drop() {
        let executed = Arc::new(AtomicBool::new(false));
        let executed_clone = executed.clone();
        {
            let _guard = BatchGuard::new();
            queue_update(move || {
                executed_clone.store(true, Ordering::SeqCst);
            });
            assert!(!executed.load(Ordering::SeqCst));
        }
        assert!(executed.load(Ordering::SeqCst));
    }

    #[test]
    fn test_batch_with_queue_update() {
        use std::sync::Mutex;
        let results = Arc::new(Mutex::new(Vec::new()));
        batch(|| {
            let r1 = results.clone();
            queue_update(move || {
                r1.lock().unwrap().push(1);
            });
            let r2 = results.clone();
            queue_update(move || {
                r2.lock().unwrap().push(2);
            });
            let r3 = results.clone();
            queue_update(move || {
                r3.lock().unwrap().push(3);
            });
            assert!(
                results.lock().unwrap().is_empty(),
                "updates must be deferred while batching"
            );
        });
        let results_vec = results.lock().unwrap();
        // Deferred updates run in the order they were queued.
        assert_eq!(*results_vec, vec![1, 2, 3]);
    }

    #[test]
    fn test_flush_inside_batch() {
        let executed = Arc::new(AtomicBool::new(false));
        let executed_clone = executed.clone();
        batch(|| {
            queue_update(move || {
                executed_clone.store(true, Ordering::SeqCst);
            });
            assert!(!executed.load(Ordering::SeqCst));
            flush();
            assert!(executed.load(Ordering::SeqCst));
        });
    }

    #[test]
    fn test_transaction_commit_in_batch() {
        let executed = Arc::new(AtomicBool::new(false));
        let executed_clone = executed.clone();
        batch(|| {
            let mut tx = Transaction::new();
            tx.update(move || {
                executed_clone.store(true, Ordering::SeqCst);
            });
            tx.commit();
        });
        assert!(executed.load(Ordering::SeqCst));
    }
}