use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Condvar, Mutex, MutexGuard};
use std::time::{Duration, Instant};
/// A batch of threads that are all waiting for the same fsync to happen.
///
/// Exactly one member is expected to perform the fsync (the "leader"); the
/// others park in [`FSyncGroup::wait_for_event`] until the leader reports
/// completion through `wakeup_all` / `wakeup_all_with_error`.
struct FSyncGroup {
    /// Flags shared between the waiters and whoever completes the fsync.
    inner: Mutex<FsyncGroupInner>,
    /// Signalled by `wakeup_all*` (every waiter) and `wakeup_one` (hand-off).
    condvar: Condvar,
}

/// Mutable state of an [`FSyncGroup`], guarded by `FSyncGroup::inner`.
struct FsyncGroupInner {
    /// Set once the group's fsync has finished, successfully or not.
    work_done: bool,
    /// Set as soon as one waiter has been told to act as leader.
    leader_exists: bool,
    /// Message of the failed fsync, if the group's fsync failed.
    error: Option<String>,
    /// Set by `wakeup_one`; remembers a leadership hand-off so that it is not
    /// lost when no waiter was parked on the condvar at notification time.
    leader_requested: bool,
}

/// Outcome of [`FSyncGroup::wait_for_event`], telling the caller what to do.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WaitStatus {
    /// The group's fsync was completed by another thread.
    NoFsyncNeeded,
    /// The caller has been chosen as the group's leader.
    DoLeaderFsync,
    /// The timeout elapsed while a leader existed but had not finished.
    DoTimeoutFsync,
}

impl FSyncGroup {
    /// Creates an empty group: no work done, no leader, no error.
    fn new() -> Arc<Self> {
        Arc::new(FSyncGroup {
            inner: Mutex::new(FsyncGroupInner {
                work_done: false,
                leader_exists: false,
                error: None,
                leader_requested: false,
            }),
            condvar: Condvar::new(),
        })
    }

    /// Blocks until the group's fsync is done, the caller is elected leader,
    /// or `timeout` has elapsed.
    ///
    /// Returns immediately when the work is already done, or when a leader
    /// hand-off was requested via `wakeup_one` and nobody has claimed it yet.
    /// After parking, any wake-up (notification, timeout or spurious) that
    /// finds the group without a leader elects the caller.
    ///
    /// # Panics
    ///
    /// Panics if the group's mutex is poisoned.
    fn wait_for_event(&self, timeout: Duration) -> WaitStatus {
        let mut inner = self.inner.lock().unwrap();
        if inner.work_done {
            return WaitStatus::NoFsyncNeeded;
        }
        // A hand-off that arrived before we got here must not be lost:
        // otherwise this thread would sleep for the whole timeout even though
        // it was already asked to lead.
        if inner.leader_requested && !inner.leader_exists {
            inner.leader_exists = true;
            return WaitStatus::DoLeaderFsync;
        }

        let start = Instant::now();
        loop {
            let elapsed = start.elapsed();
            if elapsed >= timeout {
                return WaitStatus::DoTimeoutFsync;
            }
            let (guard, _timed_out) = self
                .condvar
                .wait_timeout(inner, timeout - elapsed)
                .unwrap();
            inner = guard;

            if inner.work_done {
                return WaitStatus::NoFsyncNeeded;
            }
            if !inner.leader_exists {
                inner.leader_exists = true;
                return WaitStatus::DoLeaderFsync;
            }
            // A leader exists but has not finished yet: go around again; the
            // deadline check at the top of the loop ends the wait on timeout.
        }
    }

    /// Marks the group's fsync as successfully completed and wakes every
    /// waiter. Any previously stored error is cleared.
    fn wakeup_all(&self) {
        let mut inner = self.inner.lock().unwrap();
        inner.work_done = true;
        inner.error = None;
        drop(inner);
        self.condvar.notify_all();
    }

    /// Marks the group's fsync as completed with failure `msg` and wakes every
    /// waiter; the message can be read back with `take_error`.
    fn wakeup_all_with_error(&self, msg: String) {
        let mut inner = self.inner.lock().unwrap();
        inner.work_done = true;
        inner.error = Some(msg);
        drop(inner);
        self.condvar.notify_all();
    }

    /// Asks one member of the group to become its leader.
    ///
    /// The request is recorded under the lock before notifying, so a waiter
    /// that has not parked yet still observes it in `wait_for_event`.
    fn wakeup_one(&self) {
        let mut inner = self.inner.lock().unwrap();
        inner.leader_requested = true;
        drop(inner);
        self.condvar.notify_one();
    }

    /// Returns a copy of the stored fsync error, if any.
    ///
    /// Despite the name the error is cloned, not removed, so every waiter of
    /// the group can read the same failure.
    fn take_error(&self) -> Option<String> {
        self.inner.lock().unwrap().error.clone()
    }
}
// Coordination state of an `FsyncManager`, guarded by `FsyncManager::state`.
struct FsyncState {
// True from the moment a thread claims leadership until that leader has
// finished its fsync and woken its group.
work_in_progress: bool,
// Group joined by callers that arrive while an fsync is in progress; it is
// replaced by a fresh group whenever a leader takes over.
next_fsync_waiters: Arc<FSyncGroup>,
// Number of callers that have joined `next_fsync_waiters` since it was
// last replaced.
num_next_waiters: usize,
// When the first waiter joined the current group; only recorded while
// group-commit waiting is enabled.
start_next_wait: Option<Instant>,
}
/// Coalesces concurrent fsync requests so that one caller (the leader)
/// performs the fsync on behalf of a group of waiting callers, and keeps
/// counters describing how the requests were served.
pub struct FsyncManager {
// Number of queued waiters at which a leader stops delaying its fsync.
grpc_threshold: usize,
// Upper bound, in milliseconds, on how long a leader delays its fsync to
// let more waiters join.
grpc_interval_ms: u64,
// True only when both `grpc_threshold` and `grpc_interval_ms` are non-zero.
grp_wait_on: bool,
// How long a waiter blocks on its group before acting on its own.
fsync_timeout: Duration,
// Leader/waiter bookkeeping shared by all callers.
state: Mutex<FsyncState>,
// Signalled when the number of queued waiters reaches `grpc_threshold`.
leader_condvar: Condvar,
// Number of times the fsync closure was invoked.
n_fsyncs: AtomicU64,
// Number of calls to `fsync`.
n_fsync_requests: AtomicU64,
// Number of waits that ended with `WaitStatus::DoTimeoutFsync`.
n_fsync_timeouts: AtomicU64,
// Number of fsyncs performed by a leader.
n_group_commits: AtomicU64,
// Accumulated wall-clock time spent inside the fsync closure, in ms.
fsync_time_ms: AtomicU64,
// Sum of the waiter counts observed by leaders when they took over.
n_fsync_batch_size_sum: AtomicU64,
}
impl FsyncManager {
    /// Creates a manager with a 500 ms waiter timeout.
    ///
    /// Group-commit waiting (a leader delaying its fsync so that more callers
    /// can join) is enabled only when both `grpc_threshold` and
    /// `grpc_interval_ms` are non-zero.
    pub fn new(grpc_threshold: usize, grpc_interval_ms: u64) -> Self {
        let grp_wait_on = grpc_threshold != 0 && grpc_interval_ms != 0;
        FsyncManager {
            grpc_threshold,
            grpc_interval_ms,
            grp_wait_on,
            fsync_timeout: Duration::from_millis(500),
            state: Mutex::new(FsyncState {
                work_in_progress: false,
                next_fsync_waiters: FSyncGroup::new(),
                num_next_waiters: 0,
                start_next_wait: None,
            }),
            leader_condvar: Condvar::new(),
            n_fsyncs: AtomicU64::new(0),
            n_fsync_requests: AtomicU64::new(0),
            n_fsync_timeouts: AtomicU64::new(0),
            n_group_commits: AtomicU64::new(0),
            fsync_time_ms: AtomicU64::new(0),
            n_fsync_batch_size_sum: AtomicU64::new(0),
        }
    }

    /// Ensures an fsync that started after this call has completed, running
    /// `do_fsync` at most once in the calling thread.
    ///
    /// If no fsync is in progress the caller becomes the leader and runs
    /// `do_fsync` for the pending group. Otherwise it joins the pending group
    /// and waits: it returns without running `do_fsync` when a leader completes
    /// the group's fsync, becomes leader itself when asked to, or runs
    /// `do_fsync` on its own when the wait times out or leadership is already
    /// taken by another thread.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `do_fsync` when this thread ran it, or an
    /// `std::io::Error` carrying the leader's error message when the group's
    /// fsync was run by another thread and failed.
    ///
    /// # Panics
    ///
    /// Panics if one of the internal mutexes is poisoned.
    pub fn fsync<F>(&self, do_fsync: F) -> std::io::Result<()>
    where
        F: FnOnce() -> std::io::Result<()>,
    {
        self.n_fsync_requests.fetch_add(1, Ordering::Relaxed);

        // `Some((group, batch_size))` once this thread leads `group`.
        let mut leadership: Option<(Arc<FSyncGroup>, u64)> = None;
        // `Some(group)` when another fsync is running and we must wait.
        let mut waiting_on: Option<Arc<FSyncGroup>> = None;
        {
            let mut state = self.state.lock().unwrap();
            if state.work_in_progress {
                state.num_next_waiters += 1;
                if self.grp_wait_on {
                    if state.num_next_waiters == 1 {
                        state.start_next_wait = Some(Instant::now());
                    }
                    // Enough waiters queued: release a leader that is
                    // delaying its fsync in `grpc_wait`.
                    if state.num_next_waiters >= self.grpc_threshold {
                        self.leader_condvar.notify_one();
                    }
                }
                waiting_on = Some(Arc::clone(&state.next_fsync_waiters));
            } else {
                leadership = Some(self.begin_leadership(state));
            }
        }

        if let Some(group) = waiting_on {
            match group.wait_for_event(self.fsync_timeout) {
                WaitStatus::NoFsyncNeeded => {
                    return match group.take_error() {
                        Some(msg) => Err(std::io::Error::other(msg)),
                        None => Ok(()),
                    };
                }
                WaitStatus::DoLeaderFsync => {
                    let state = self.state.lock().unwrap();
                    // Another thread may have claimed leadership in the
                    // meantime; then we still fsync, but only for ourselves.
                    if !state.work_in_progress {
                        leadership = Some(self.begin_leadership(state));
                    }
                }
                WaitStatus::DoTimeoutFsync => {
                    self.n_fsync_timeouts.fetch_add(1, Ordering::Relaxed);
                }
            }
        }

        // Every path that reaches this point has to perform the fsync itself.
        self.n_fsyncs.fetch_add(1, Ordering::Relaxed);
        let fsync_start = Instant::now();
        let result = do_fsync();
        let elapsed_ms = u64::try_from(fsync_start.elapsed().as_millis()).unwrap_or(u64::MAX);
        self.fsync_time_ms.fetch_add(elapsed_ms, Ordering::Relaxed);

        if let Some((group, batch_size)) = leadership {
            self.n_group_commits.fetch_add(1, Ordering::Relaxed);
            self.n_fsync_batch_size_sum
                .fetch_add(batch_size, Ordering::Relaxed);
            match &result {
                Ok(()) => group.wakeup_all(),
                Err(e) => group.wakeup_all_with_error(e.to_string()),
            }
            // Hand leadership to one member of the group that queued up while
            // we were working, then allow new callers to lead directly.
            let mut state = self.state.lock().unwrap();
            state.next_fsync_waiters.wakeup_one();
            state.work_in_progress = false;
        }
        result
    }

    /// Number of times the fsync closure was invoked.
    pub fn fsync_count(&self) -> u64 {
        self.n_fsyncs.load(Ordering::Relaxed)
    }

    /// Number of waits that ended in a timeout.
    pub fn fsync_timeout_count(&self) -> u64 {
        self.n_fsync_timeouts.load(Ordering::Relaxed)
    }

    /// Number of fsyncs performed by a leader.
    pub fn group_commit_count(&self) -> u64 {
        self.n_group_commits.load(Ordering::Relaxed)
    }

    /// Accumulated time spent inside the fsync closure, in milliseconds.
    pub fn fsync_time_ms(&self) -> u64 {
        self.fsync_time_ms.load(Ordering::Relaxed)
    }

    /// Sum of the waiter counts observed by leaders when they took over.
    pub fn fsync_batch_size_sum(&self) -> u64 {
        self.n_fsync_batch_size_sum.load(Ordering::Relaxed)
    }

    /// Number of calls to [`FsyncManager::fsync`].
    pub fn fsync_request_count(&self) -> u64 {
        self.n_fsync_requests.load(Ordering::Relaxed)
    }

    /// Claims leadership and detaches the pending group for the new leader.
    ///
    /// Always takes the group that is *currently* pending rather than a group
    /// the caller remembered from earlier: a remembered group may already have
    /// been led and completed by another thread, and leading it again would
    /// orphan the current waiters and overwrite that group's stored result.
    ///
    /// Returns the group to wake after the fsync and the number of waiters it
    /// held. The state lock is released when this function returns.
    fn begin_leadership<'a>(
        &'a self,
        mut state: MutexGuard<'a, FsyncState>,
    ) -> (Arc<FSyncGroup>, u64) {
        state.work_in_progress = true;
        if self.grp_wait_on {
            state = self.grpc_wait(state);
        }
        let batch_size = state.num_next_waiters as u64;
        let group = std::mem::replace(&mut state.next_fsync_waiters, FSyncGroup::new());
        state.num_next_waiters = 0;
        (group, batch_size)
    }

    /// Delays the leader so that more callers can join the pending group.
    ///
    /// Returns at once when nobody is queued or the threshold is already met.
    /// Otherwise waits until the threshold is reached or until
    /// `grpc_interval_ms` has passed since the first waiter queued, whichever
    /// comes first. Spurious wake-ups do not end the wait early. The state
    /// lock is released while waiting and re-acquired before returning.
    fn grpc_wait<'a>(
        &'a self,
        state: MutexGuard<'a, FsyncState>,
    ) -> MutexGuard<'a, FsyncState> {
        let threshold = self.grpc_threshold;
        if state.num_next_waiters == 0 || state.num_next_waiters >= threshold {
            return state;
        }
        // Saturating `Duration` arithmetic avoids the truncating nanosecond
        // cast that very large intervals would otherwise hit.
        let interval = Duration::from_millis(self.grpc_interval_ms);
        let already_waited = state
            .start_next_wait
            .map_or(Duration::ZERO, |since| since.elapsed());
        let remaining = interval.saturating_sub(already_waited);
        if remaining.is_zero() {
            return state;
        }
        let (state, _timed_out) = self
            .leader_condvar
            .wait_timeout_while(state, remaining, |s| s.num_next_waiters < threshold)
            .unwrap();
        state
    }
}
// Unit tests for `FSyncGroup` and `FsyncManager`. Several of them coordinate
// threads with barriers and short sleeps, so their outcome depends on timing.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
// A single call on a manager with group waiting disabled (threshold 0,
// interval 0) runs the closure exactly once.
#[test]
fn test_simple_fsync_no_grouping() {
let mgr = FsyncManager::new(0, 0);
let count = Arc::new(AtomicUsize::new(0));
let c = count.clone();
mgr.fsync(|| {
c.fetch_add(1, Ordering::SeqCst);
Ok(())
})
.unwrap();
assert_eq!(count.load(Ordering::SeqCst), 1);
}
// Three threads released together by a barrier each request an fsync whose
// closure sleeps 20 ms; the test expects fewer than three closure runs.
// NOTE(review): the expectation relies on the threads overlapping in time.
#[test]
fn test_multiple_threads_one_fsync() {
let mgr = Arc::new(FsyncManager::new(0, 0));
let fsync_count = Arc::new(AtomicUsize::new(0));
let barrier = Arc::new(std::sync::Barrier::new(3));
let mut handles = vec![];
for _ in 0..3 {
let mgr2 = Arc::clone(&mgr);
let fc = Arc::clone(&fsync_count);
let b = Arc::clone(&barrier);
handles.push(std::thread::spawn(move || {
b.wait();
mgr2.fsync(|| {
std::thread::sleep(Duration::from_millis(20));
fc.fetch_add(1, Ordering::SeqCst);
Ok(())
})
.unwrap();
}));
}
for h in handles {
h.join().unwrap();
}
let total = fsync_count.load(Ordering::SeqCst);
assert!(
total < 3,
"expected coalescing (total < 3 fsyncs), got {}",
total
);
}
// A failing closure run by the calling thread is returned to that caller.
// Only one thread is involved here, so no waiter is exercised.
#[test]
fn test_fsync_error_propagated_to_waiters() {
let mgr = FsyncManager::new(0, 0);
let result =
mgr.fsync(|| Err(std::io::Error::other("simulated fsync failure")));
assert!(result.is_err());
assert!(
result.unwrap_err().to_string().contains("simulated fsync failure")
);
}
// Four unsynchronised threads call fsync on a manager with threshold 2 and
// a 50 ms interval; the test only checks that between 1 and 4 closure runs
// happened.
#[test]
fn test_grpc_threshold_respected() {
let mgr = Arc::new(FsyncManager::new(2, 50));
let fsync_count = Arc::new(AtomicUsize::new(0));
let mut handles = vec![];
for _ in 0..4 {
let m = Arc::clone(&mgr);
let fc = Arc::clone(&fsync_count);
handles.push(std::thread::spawn(move || {
m.fsync(|| {
fc.fetch_add(1, Ordering::SeqCst);
Ok(())
})
.unwrap();
}));
}
for h in handles {
h.join().unwrap();
}
let total = fsync_count.load(Ordering::SeqCst);
assert!(total >= 1, "at least one fsync must have run");
assert!(total <= 4, "unexpected fsync count: {}", total);
}
// Five back-to-back calls from one thread each run the closure.
#[test]
fn test_sequential_calls_each_fsync_once() {
let mgr = FsyncManager::new(0, 0);
let count = Arc::new(AtomicUsize::new(0));
for _ in 0..5 {
let c = count.clone();
mgr.fsync(|| {
c.fetch_add(1, Ordering::SeqCst);
Ok(())
})
.unwrap();
}
assert_eq!(count.load(Ordering::SeqCst), 5);
}
// A spawned thread runs a failing 30 ms fsync while the main thread issues
// its own request about 2 ms after the shared barrier.
// NOTE(review): only the spawned thread's error is asserted; the main
// thread's result is bound and then explicitly discarded.
#[test]
fn test_fsync_error_forwarded_to_waiting_threads() {
let mgr = Arc::new(FsyncManager::new(0, 0));
let barrier = Arc::new(std::sync::Barrier::new(2));
let mgr2 = Arc::clone(&mgr);
let b2 = Arc::clone(&barrier);
let leader = std::thread::spawn(move || {
b2.wait();
mgr2.fsync(|| {
std::thread::sleep(Duration::from_millis(30));
Err(std::io::Error::other("leader fail"))
})
});
barrier.wait();
std::thread::sleep(Duration::from_millis(2));
let waiter_result = mgr.fsync(|| {
Ok(())
});
let leader_result = leader.join().unwrap();
assert!(leader_result.is_err());
let _ = waiter_result; }
// A successful closure yields `Ok`.
#[test]
fn test_returns_ok_on_success() {
let mgr = FsyncManager::new(0, 0);
assert!(mgr.fsync(|| Ok(())).is_ok());
}
// `wakeup_all` marks the group done and leaves no error behind.
#[test]
fn test_fsync_group_wakeup_all() {
let g = FSyncGroup::new();
g.wakeup_all();
assert!(g.inner.lock().unwrap().work_done);
assert!(g.take_error().is_none());
}
// `wakeup_all_with_error` marks the group done and stores the message.
#[test]
fn test_fsync_group_wakeup_all_with_error() {
let g = FSyncGroup::new();
g.wakeup_all_with_error("oops".to_string());
assert!(g.inner.lock().unwrap().work_done);
assert_eq!(g.take_error().unwrap(), "oops");
}
// Waiting on a group whose work is already done returns without blocking
// for the 5 s timeout.
#[test]
fn test_fsync_group_already_done() {
let g = FSyncGroup::new();
g.wakeup_all();
let status = g.wait_for_event(Duration::from_secs(5));
assert_eq!(status, WaitStatus::NoFsyncNeeded);
}
// A waiter on a leaderless group is told to lead, and the group records
// that a leader exists. A helper thread calls `wakeup_one` after 10 ms.
// `FSyncGroup::new()` already returns an `Arc`, so `g` is a nested `Arc`.
#[test]
fn test_fsync_group_becomes_leader_on_wakeup() {
let g = Arc::new(FSyncGroup::new());
let g2 = Arc::clone(&g);
std::thread::spawn(move || {
std::thread::sleep(Duration::from_millis(10));
g2.wakeup_one();
});
let status = g.wait_for_event(Duration::from_millis(500));
assert_eq!(status, WaitStatus::DoLeaderFsync);
assert!(g.inner.lock().unwrap().leader_exists);
}
// With a leader already marked and no completion, a 20 ms wait times out.
#[test]
fn test_fsync_group_timeout() {
let g = FSyncGroup::new();
g.inner.lock().unwrap().leader_exists = true;
let status = g.wait_for_event(Duration::from_millis(20));
assert_eq!(status, WaitStatus::DoTimeoutFsync);
}
// The three `WaitStatus` variants compare unequal to one another.
#[test]
fn test_wait_status_variants_distinct() {
assert_ne!(WaitStatus::NoFsyncNeeded, WaitStatus::DoLeaderFsync);
assert_ne!(WaitStatus::NoFsyncNeeded, WaitStatus::DoTimeoutFsync);
assert_ne!(WaitStatus::DoLeaderFsync, WaitStatus::DoTimeoutFsync);
}
// Group waiting is enabled only when threshold and interval are both
// non-zero.
#[test]
fn test_grp_wait_on_requires_both_nonzero() {
let m1 = FsyncManager::new(0, 100);
assert!(!m1.grp_wait_on);
let m2 = FsyncManager::new(2, 0);
assert!(!m2.grp_wait_on);
let m3 = FsyncManager::new(2, 100);
assert!(m3.grp_wait_on);
}
// Stress test: 8 threads x 200 operations against a manager with
// threshold 2 and a 5 ms interval. Each operation allocates an LSN, raises
// `snap_lsn` to at least that LSN, then calls `fsync` with a closure that
// raises `flushed_lsn` to the `snap_lsn` it observes. After `fsync` returns
// the thread asserts that `flushed_lsn` has reached its own LSN.
#[test]
fn test_fsync_before_commit_invariant() {
use std::sync::atomic::AtomicU64;
const N_THREADS: usize = 8;
const OPS_PER_THREAD: usize = 200;
let next_lsn = Arc::new(AtomicU64::new(1));
let flushed_lsn = Arc::new(AtomicU64::new(0));
let snap_lsn = Arc::new(AtomicU64::new(0));
let mgr = Arc::new(FsyncManager::new(2, 5));
let barrier = Arc::new(std::sync::Barrier::new(N_THREADS));
let mut handles = vec![];
for _ in 0..N_THREADS {
let mgr2 = Arc::clone(&mgr);
let b = Arc::clone(&barrier);
let nl = Arc::clone(&next_lsn);
let fl = Arc::clone(&flushed_lsn);
let sl = Arc::clone(&snap_lsn);
handles.push(std::thread::spawn(move || {
b.wait();
for _ in 0..OPS_PER_THREAD {
let my_lsn =
nl.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
// Monotonically raise `snap_lsn` to at least `my_lsn`.
let mut cur = sl.load(std::sync::atomic::Ordering::Relaxed);
while cur < my_lsn {
match sl.compare_exchange(
cur,
my_lsn,
std::sync::atomic::Ordering::SeqCst,
std::sync::atomic::Ordering::Relaxed,
) {
Ok(_) => break,
Err(a) => cur = a,
}
}
let fl2 = Arc::clone(&fl);
let sl2 = Arc::clone(&sl);
mgr2.fsync(move || {
// The simulated fsync covers everything up to the
// `snap_lsn` visible when it runs.
let covered =
sl2.load(std::sync::atomic::Ordering::SeqCst);
let mut f =
fl2.load(std::sync::atomic::Ordering::Relaxed);
while f < covered {
match fl2.compare_exchange(
f,
covered,
std::sync::atomic::Ordering::SeqCst,
std::sync::atomic::Ordering::Relaxed,
) {
Ok(_) => break,
Err(a) => f = a,
}
}
Ok(())
})
.unwrap();
let fl_now = fl.load(std::sync::atomic::Ordering::SeqCst);
assert!(
fl_now >= my_lsn,
"fsync-before-commit violated: \
flushed_lsn={fl_now} < commit_lsn={my_lsn}"
);
}
}));
}
for h in handles {
h.join().unwrap();
}
}
}