use std::collections::VecDeque;
use std::fs::{File, OpenOptions};
use std::io::{self, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
/// Default for `IoUringWalConfig::ring_size` (the field is not read by any code in this file).
const DEFAULT_RING_SIZE: u32 = 256;
/// Default for `IoUringWalConfig::batch_size`: queued writes that trigger an immediate batch write.
const DEFAULT_BATCH_SIZE: usize = 32;
/// Default for `IoUringWalConfig::batch_timeout_us` (the field is not read by any code in this file).
const DEFAULT_BATCH_TIMEOUT_US: u64 = 100;
/// Handle for one queued write. All clones share the same completion state, so
/// the WAL can complete the operation while callers hold and wait on copies.
#[derive(Clone)]
pub struct CompletionToken {
    // Operation id assigned at creation; kept for debugging, never read.
    #[allow(dead_code)]
    id: u64,
    // Set (Release) after `result` has been stored.
    completed: Arc<AtomicBool>,
    // On success: bytes written. On failure: ERROR_FLAG | (error code as u32).
    result: Arc<AtomicU64>,
}

impl CompletionToken {
    /// High bit of `result` that marks a failed operation.
    const ERROR_FLAG: u64 = 1 << 63;
    /// Busy-spin iterations in `wait` before switching to `yield_now`.
    const SPIN_LIMIT: u32 = 64;

    fn new(id: u64) -> Self {
        Self {
            id,
            completed: Arc::new(AtomicBool::new(false)),
            result: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Returns `true` once `complete` or `fail` has been called on any clone.
    #[inline]
    pub fn is_completed(&self) -> bool {
        self.completed.load(Ordering::Acquire)
    }

    /// Blocks the calling thread until the operation finishes.
    ///
    /// Returns the byte count passed to `complete`. This never returns if
    /// neither `complete` nor `fail` is ever called on any clone of this token
    /// (e.g. the write is still queued and nothing flushes it).
    ///
    /// # Errors
    /// Returns `io::Error::from_raw_os_error(code)` with the exact `code`
    /// passed to `fail`, including negative codes.
    pub fn wait(&self) -> io::Result<usize> {
        let mut spins: u32 = 0;
        while !self.is_completed() {
            // Spin briefly for fast completions, then yield so a long wait does
            // not monopolise a core (possibly the one the writer needs).
            if spins < Self::SPIN_LIMIT {
                std::hint::spin_loop();
                spins += 1;
            } else {
                std::thread::yield_now();
            }
        }
        // The Acquire load of `completed` above pairs with the Release store in
        // `complete`/`fail`, so the `result` written before it is visible here.
        let result = self.result.load(Ordering::Acquire);
        if result & Self::ERROR_FLAG != 0 {
            // The low 32 bits hold the i32 bit pattern; reinterpret it so negative
            // codes (e.g. -1) survive instead of being masked to i32::MAX.
            Err(io::Error::from_raw_os_error(result as u32 as i32))
        } else {
            Ok(result as usize)
        }
    }

    /// Records success with `bytes_written` and wakes waiters.
    ///
    /// `bytes_written` comes from a `Vec` length (at most `isize::MAX`), so it
    /// never collides with `ERROR_FLAG`.
    fn complete(&self, bytes_written: usize) {
        self.result.store(bytes_written as u64, Ordering::Release);
        self.completed.store(true, Ordering::Release);
    }

    /// Records failure with `error_code` and wakes waiters.
    fn fail(&self, error_code: i32) {
        // Go through u32 so a negative code keeps only its 32-bit pattern and
        // does not sign-extend over the upper bits.
        self.result.store(
            Self::ERROR_FLAG | u64::from(error_code as u32),
            Ordering::Release,
        );
        self.completed.store(true, Ordering::Release);
    }
}
/// A single queued write: the file offset reserved for it, its payload, and
/// the token to complete once the payload has been written.
struct SubmissionEntry {
    offset: u64,
    data: Vec<u8>,
    token: CompletionToken,
}

/// FIFO of queued writes plus a running total of their payload sizes.
struct BatchSubmitter {
    pending: VecDeque<SubmissionEntry>,
    // Queue length at which `should_submit` starts returning true.
    batch_size: usize,
    // Sum of `data.len()` over everything in `pending`.
    pending_bytes: usize,
}

impl BatchSubmitter {
    /// Creates an empty queue with room for `batch_size` entries.
    fn new(batch_size: usize) -> Self {
        let pending = VecDeque::with_capacity(batch_size);
        BatchSubmitter {
            pending,
            batch_size,
            pending_bytes: 0,
        }
    }

    /// Appends `entry` to the back of the queue and accounts for its bytes.
    fn push(&mut self, entry: SubmissionEntry) {
        let size = entry.data.len();
        self.pending_bytes += size;
        self.pending.push_back(entry);
    }

    /// True once at least `batch_size` entries are queued (always true for 0).
    fn should_submit(&self) -> bool {
        self.len() >= self.batch_size
    }

    /// Removes every queued entry, oldest first, and resets the byte count.
    /// The queue keeps its allocation for reuse.
    fn take_batch(&mut self) -> Vec<SubmissionEntry> {
        let batch: Vec<SubmissionEntry> = self.pending.drain(..).collect();
        self.pending_bytes = 0;
        batch
    }

    /// Number of queued entries.
    fn len(&self) -> usize {
        self.pending.len()
    }
}
/// Options accepted by `IoUringWal::open`.
#[derive(Clone)]
pub struct IoUringWalConfig {
    /// Number of queued writes that makes `write` flush the queue immediately.
    pub batch_size: usize,
    /// If non-zero, `open` extends a shorter file to this many bytes.
    pub preallocate_size: u64,
    /// NOTE(review): intended io_uring ring size; not read by any code in this file.
    pub ring_size: u32,
    /// NOTE(review): intended batch flush timeout in microseconds; not read by any code in this file.
    pub batch_timeout_us: u64,
    /// NOTE(review): `open` does not request direct I/O; this flag is not read by any code in this file.
    pub use_direct_io: bool,
}

impl Default for IoUringWalConfig {
    /// Module `DEFAULT_*` constants, no direct I/O, and 64 MiB of preallocation.
    fn default() -> Self {
        const PREALLOCATE_64_MIB: u64 = 64 * 1024 * 1024;
        IoUringWalConfig {
            batch_size: DEFAULT_BATCH_SIZE,
            preallocate_size: PREALLOCATE_64_MIB,
            ring_size: DEFAULT_RING_SIZE,
            batch_timeout_us: DEFAULT_BATCH_TIMEOUT_US,
            use_direct_io: false,
        }
    }
}
/// Write-ahead log that queues writes in memory and writes them to a file in batches.
///
/// Despite the name, every write in this file uses ordinary synchronous file
/// I/O; no io_uring instance is created.
pub struct IoUringWal {
    // Backing log file.
    file: File,
    // Path the file was opened from; retained but never read in this file.
    #[allow(dead_code)]
    path: PathBuf,
    // Configuration passed to `open`; retained but not read after construction.
    #[allow(dead_code)]
    config: IoUringWalConfig,
    // Writes waiting to be written to `file`.
    submitter: parking_lot::Mutex<BatchSubmitter>,
    // Source of ids for new completion tokens.
    next_op_id: AtomicU64,
    // Next free byte offset; each write reserves `data.len()` bytes from here.
    write_offset: AtomicU64,
    // Cumulative totals over successful writes only.
    total_bytes: AtomicU64,
    total_ops: AtomicU64,
    // Set by `shutdown`; once observed, new writes are rejected.
    shutdown: AtomicBool,
}
impl IoUringWal {
    /// Opens the log file at `path`, creating it if missing.
    ///
    /// If `config.preallocate_size` is non-zero and the file is shorter than
    /// that, the file is extended to exactly that length by writing one zero
    /// byte at its last position.
    ///
    /// NOTE(review): `write_offset` always starts at 0, so reopening an existing
    /// log overwrites it from the beginning. Confirm this is intended. Of
    /// `config`, only `batch_size` and `preallocate_size` are acted on here.
    /// `ring_size`, `batch_timeout_us` and `use_direct_io` are ignored.
    ///
    /// # Errors
    /// Returns any I/O error from opening, inspecting or extending the file.
    pub fn open<P: AsRef<Path>>(path: P, config: IoUringWalConfig) -> io::Result<Self> {
        let path = path.as_ref().to_path_buf();
        let mut file = OpenOptions::new()
            .create(true)
            .read(true)
            .write(true)
            // Keep existing contents (the default); spelled out to make intent explicit.
            .truncate(false)
            .open(&path)?;
        if config.preallocate_size > 0 && file.metadata()?.len() < config.preallocate_size {
            file.seek(SeekFrom::Start(config.preallocate_size - 1))?;
            file.write_all(&[0])?;
            file.seek(SeekFrom::Start(0))?;
        }
        // Read what we need up front so `config` can be moved in without a clone.
        let batch_size = config.batch_size;
        Ok(Self {
            path,
            file,
            config,
            submitter: parking_lot::Mutex::new(BatchSubmitter::new(batch_size)),
            next_op_id: AtomicU64::new(0),
            write_offset: AtomicU64::new(0),
            total_bytes: AtomicU64::new(0),
            total_ops: AtomicU64::new(0),
            shutdown: AtomicBool::new(false),
        })
    }

    /// Queues `data` at the next free file offset and returns its completion token.
    ///
    /// If the queue reaches `batch_size` entries, the whole queue is written
    /// before returning. The token completes once the bytes are written, not
    /// once they are synced; call `flush` for durability.
    ///
    /// # Errors
    /// Returns an error if `shutdown` has been called.
    pub fn write(&self, data: Vec<u8>) -> io::Result<CompletionToken> {
        let (token, should_submit) = {
            let mut submitter = self.submitter.lock();
            // Check the flag under the submitter lock. This orders the write against
            // `shutdown()`'s final flush. Either the entry is queued before that flush
            // drains the queue, or the flag is already visible and the write is
            // rejected. Without this, an entry could be stranded and its token would
            // never complete.
            if self.shutdown.load(Ordering::Acquire) {
                return Err(io::Error::new(io::ErrorKind::Other, "WAL is shutdown"));
            }
            let token = CompletionToken::new(self.next_op_id.fetch_add(1, Ordering::Relaxed));
            // Reserving under the lock also keeps queue order equal to offset order.
            let offset = self
                .write_offset
                .fetch_add(data.len() as u64, Ordering::Relaxed);
            submitter.push(SubmissionEntry {
                data,
                offset,
                token: token.clone(),
            });
            (token, submitter.should_submit())
        };
        if should_submit {
            self.submit_batch()?;
        }
        Ok(token)
    }

    /// Drains the queue (under the lock) and writes the drained entries (outside it).
    fn submit_batch(&self) -> io::Result<()> {
        let entries = {
            let mut submitter = self.submitter.lock();
            submitter.take_batch()
        };
        if entries.is_empty() {
            return Ok(());
        }
        self.submit_sync(entries)
    }

    /// Writes each entry in order and completes or fails its token.
    ///
    /// Per-entry failures are reported only through the tokens; this always
    /// returns `Ok(())`. Errors without an OS code are reported as code `-1`.
    fn submit_sync(&self, entries: Vec<SubmissionEntry>) -> io::Result<()> {
        for entry in entries {
            match self.do_write(&entry) {
                Ok(bytes) => {
                    entry.token.complete(bytes);
                    self.total_bytes.fetch_add(bytes as u64, Ordering::Relaxed);
                    self.total_ops.fetch_add(1, Ordering::Relaxed);
                }
                Err(e) => {
                    entry.token.fail(e.raw_os_error().unwrap_or(-1));
                }
            }
        }
        Ok(())
    }

    /// Writes all of `entry.data` at `entry.offset` and returns its length.
    ///
    /// On Unix this uses `write_all_at`. A plain `write_at` (pwrite) may write
    /// fewer bytes than requested, which would silently truncate a log record
    /// while reporting success.
    fn do_write(&self, entry: &SubmissionEntry) -> io::Result<usize> {
        #[cfg(unix)]
        {
            use std::os::unix::fs::FileExt;
            self.file.write_all_at(&entry.data, entry.offset)?;
        }
        #[cfg(not(unix))]
        {
            // NOTE(review): seek + write on the shared handle is not atomic. Two
            // threads running `submit_batch` at once (the lock is released before
            // writing) could interleave here. The Unix path is positional and
            // unaffected.
            let mut file = &self.file;
            file.seek(SeekFrom::Start(entry.offset))?;
            file.write_all(&entry.data)?;
        }
        Ok(entry.data.len())
    }

    /// Writes everything queued, then `sync_all`s the file.
    ///
    /// NOTE(review): a batch that another thread has already taken from the
    /// queue but not finished writing is not waited for. `sync_all` may
    /// therefore run before those bytes reach the file.
    ///
    /// # Errors
    /// Returns the error from `sync_all`.
    pub fn flush(&self) -> io::Result<()> {
        self.submit_batch()?;
        self.file.sync_all()
    }

    /// Writes everything queued without syncing the file.
    pub fn flush_pending(&self) -> io::Result<()> {
        self.submit_batch()
    }

    /// Returns a snapshot of the counters and the current queue size.
    pub fn stats(&self) -> WalStats {
        let submitter = self.submitter.lock();
        WalStats {
            total_bytes_written: self.total_bytes.load(Ordering::Relaxed),
            total_operations: self.total_ops.load(Ordering::Relaxed),
            current_offset: self.write_offset.load(Ordering::Relaxed),
            pending_entries: submitter.len(),
            pending_bytes: submitter.pending_bytes,
        }
    }

    /// Rejects all later writes, then writes and syncs whatever is still queued.
    pub fn shutdown(&self) -> io::Result<()> {
        self.shutdown.store(true, Ordering::Release);
        self.flush()
    }
}
/// Point-in-time snapshot returned by `IoUringWal::stats`.
#[derive(Clone, Debug)]
pub struct WalStats {
    /// Bytes written by writes that completed successfully.
    pub total_bytes_written: u64,
    /// Count of writes that completed successfully.
    pub total_operations: u64,
    /// Next offset to be reserved: the summed length of all accepted writes since open.
    pub current_offset: u64,
    /// Writes queued but not yet written.
    pub pending_entries: usize,
    /// Payload bytes held by the queued writes.
    pub pending_bytes: usize,
}
pub struct CompletionHandler {
tokens: Vec<CompletionToken>,
}
impl CompletionHandler {
pub fn new() -> Self {
Self { tokens: Vec::new() }
}
pub fn track(&mut self, token: CompletionToken) {
self.tokens.push(token);
}
pub fn wait_all(&self) -> io::Result<Vec<usize>> {
let mut results = Vec::with_capacity(self.tokens.len());
for token in &self.tokens {
results.push(token.wait()?);
}
Ok(results)
}
pub fn poll(&self) -> Vec<(usize, bool)> {
self.tokens
.iter()
.enumerate()
.map(|(i, t)| (i, t.is_completed()))
.collect()
}
pub fn completed_count(&self) -> usize {
self.tokens.iter().filter(|t| t.is_completed()).count()
}
pub fn all_completed(&self) -> bool {
self.tokens.iter().all(|t| t.is_completed())
}
pub fn clear(&mut self) {
self.tokens.clear();
}
}
impl Default for CompletionHandler {
fn default() -> Self {
Self::new()
}
}
/// Wrapper around `IoUringWal` that forces the queued batch to be written
/// after every `group_size` writes made through it.
pub struct GroupCommitWal {
    wal: IoUringWal,
    group_size: usize,
    // Stored for configuration only: no timer-based flush is implemented here.
    #[allow(dead_code)]
    group_timeout_ms: u64,
    // Writes made through this wrapper since the last group flush.
    // The original kept a `Vec` here that was never pushed to. As a result,
    // `len() >= group_size` was never true and group commit never triggered.
    pending: parking_lot::Mutex<usize>,
}

impl GroupCommitWal {
    /// Wraps `wal`, writing out its queue after every `group_size` writes.
    /// A `group_size` of 0 writes out the queue after every write.
    pub fn new(wal: IoUringWal, group_size: usize, group_timeout_ms: u64) -> Self {
        Self {
            wal,
            group_size,
            group_timeout_ms,
            pending: parking_lot::Mutex::new(0),
        }
    }

    /// Queues `data` on the inner WAL and, on every `group_size`-th write,
    /// writes out the inner queue (without syncing).
    ///
    /// # Errors
    /// Propagates errors from the inner `write` and `flush_pending`.
    pub fn write(&self, data: Vec<u8>) -> io::Result<CompletionToken> {
        let token = self.wal.write(data)?;
        let should_flush = {
            let mut pending = self.pending.lock();
            *pending += 1;
            let group_full = *pending >= self.group_size;
            if group_full {
                *pending = 0;
            }
            group_full
        };
        if should_flush {
            self.wal.flush_pending()?;
        }
        Ok(token)
    }

    /// Writes and syncs everything queued, starting a fresh group.
    pub fn flush(&self) -> io::Result<()> {
        *self.pending.lock() = 0;
        self.wal.flush()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use tempfile::tempdir;

    /// Opens `test.wal` inside `dir` with `config`, panicking on failure.
    fn open_wal(dir: &Path, config: IoUringWalConfig) -> IoUringWal {
        IoUringWal::open(dir.join("test.wal"), config).unwrap()
    }

    #[test]
    fn test_completion_token() {
        let tok = CompletionToken::new(1);
        assert!(!tok.is_completed());
        tok.complete(100);
        assert!(tok.is_completed());
        assert_eq!(tok.wait().unwrap(), 100);
    }

    #[test]
    fn test_completion_token_error() {
        let tok = CompletionToken::new(1);
        tok.fail(5);
        assert!(tok.is_completed());
        assert!(tok.wait().is_err());
    }

    #[test]
    fn test_wal_basic() {
        let dir = tempdir().unwrap();
        let wal = open_wal(
            dir.path(),
            IoUringWalConfig {
                batch_size: 4,
                preallocate_size: 1024 * 1024,
                ..Default::default()
            },
        );
        let tok = wal.write(b"hello".to_vec()).unwrap();
        wal.flush().unwrap();
        assert_eq!(tok.wait().unwrap(), 5);

        let WalStats {
            total_bytes_written,
            total_operations,
            ..
        } = wal.stats();
        assert_eq!(total_bytes_written, 5);
        assert_eq!(total_operations, 1);
    }

    #[test]
    fn test_wal_batch() {
        let dir = tempdir().unwrap();
        let wal = open_wal(
            dir.path(),
            IoUringWalConfig {
                batch_size: 4,
                ..Default::default()
            },
        );
        let mut handler = CompletionHandler::new();
        for n in 0..10 {
            let payload = format!("entry{}", n).into_bytes();
            handler.track(wal.write(payload).unwrap());
        }
        wal.flush().unwrap();
        assert!(handler.all_completed());
        assert_eq!(handler.completed_count(), 10);
    }

    #[test]
    fn test_wal_concurrent() {
        let dir = tempdir().unwrap();
        let wal = Arc::new(open_wal(
            dir.path(),
            IoUringWalConfig {
                batch_size: 1,
                ..Default::default()
            },
        ));
        let workers: Vec<_> = (0..4)
            .map(|t| {
                let wal = Arc::clone(&wal);
                thread::spawn(move || {
                    for i in 0..100 {
                        let payload = format!("thread{}:entry{}", t, i);
                        wal.write(payload.into_bytes()).unwrap().wait().unwrap();
                    }
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }
        wal.flush().unwrap();
        assert_eq!(wal.stats().total_operations, 400);
    }

    #[test]
    fn test_completion_handler() {
        let mut handler = CompletionHandler::new();
        let tokens: Vec<CompletionToken> = (1..=3).map(CompletionToken::new).collect();
        for tok in &tokens {
            handler.track(tok.clone());
        }
        assert_eq!(handler.completed_count(), 0);
        tokens[0].complete(10);
        assert_eq!(handler.completed_count(), 1);
        tokens[1].complete(20);
        tokens[2].complete(30);
        assert!(handler.all_completed());
        assert_eq!(handler.wait_all().unwrap(), vec![10, 20, 30]);
    }

    #[test]
    fn test_group_commit() {
        let dir = tempdir().unwrap();
        let group_wal =
            GroupCommitWal::new(open_wal(dir.path(), IoUringWalConfig::default()), 10, 100);
        let tokens: Vec<CompletionToken> = (0..25)
            .map(|i| group_wal.write(format!("entry{}", i).into_bytes()).unwrap())
            .collect();
        group_wal.flush().unwrap();
        for tok in &tokens {
            assert!(tok.is_completed());
        }
    }
}