use crate::error::{AmateRSError, ErrorContext, Result};
use crate::storage::wal::{WalEntry, WalEntryType};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::thread;
/// Tuning parameters for the io_uring-backed WAL writer.
#[derive(Debug, Clone)]
pub struct UringWalConfig {
    /// Number of entries requested for the io_uring instance.
    ///
    /// `UringWalWriter::open` rejects zero and any value that is not a power
    /// of two.
    pub ring_size: u32,
    /// Byte threshold for write coalescing: the writer thread keeps pulling
    /// queued write requests into the current batch while the accumulated
    /// payload size is below this value.
    pub batch_size: usize,
    /// Requests direct (unbuffered) I/O.
    ///
    /// Not implemented yet: `UringWalWriter::open` only logs a warning and
    /// falls back to buffered I/O when this is `true`.
    pub direct_io: bool,
    /// Capacity hint for the request channel.
    ///
    /// NOTE(review): nothing visible in this file reads this field; the
    /// writer currently uses an unbounded channel. Confirm before relying on
    /// it for back-pressure.
    pub channel_capacity: usize,
}

impl Default for UringWalConfig {
    /// Returns the stock configuration: a 128-entry ring, 64 KiB batches,
    /// buffered I/O, and a channel capacity of 1024.
    fn default() -> Self {
        const RING_ENTRIES: u32 = 128;
        const BATCH_BYTES: usize = 64 * 1024;
        const CHANNEL_SLOTS: usize = 1024;

        UringWalConfig {
            ring_size: RING_ENTRIES,
            batch_size: BATCH_BYTES,
            direct_io: false,
            channel_capacity: CHANNEL_SLOTS,
        }
    }
}
/// One serialized WAL record queued for the io_uring writer thread.
struct UringWriteRequest {
/// Raw bytes to write to the WAL file.
payload: Vec<u8>,
/// One-shot channel on which the writer thread reports the outcome of
/// this particular write.
done_tx: tokio::sync::oneshot::Sender<std::io::Result<()>>,
}
/// Messages sent from `UringWalWriter` handles to the writer thread.
enum UringControl {
/// Write the enclosed payload to the WAL file.
Write(UringWriteRequest),
/// Call `sync_all` on the WAL file and report the result on `done_tx`.
Flush {
done_tx: tokio::sync::oneshot::Sender<std::io::Result<()>>,
},
/// Ask the writer thread to leave its event loop.
Shutdown,
}
/// Cloneable handle to a dedicated thread that writes WAL records through
/// io_uring.
///
/// All clones share the same channel and the same join handle.
#[derive(Clone, Debug)]
pub struct UringWalWriter {
/// Sending half of the control channel consumed by the writer thread.
tx: Arc<tokio::sync::mpsc::UnboundedSender<UringControl>>,
/// Join handle of the writer thread; taken (set to `None`) by `close`.
thread_handle: Arc<Mutex<Option<thread::JoinHandle<()>>>>,
/// Configuration this writer was opened with.
config: UringWalConfig,
/// Path of the WAL file passed to `open`.
path: PathBuf,
}
impl UringWalWriter {
pub fn open(path: impl AsRef<Path>, config: UringWalConfig) -> Result<Self> {
let path = path.as_ref().to_path_buf();
if config.direct_io {
tracing::warn!(
path = %path.display(),
"UringWalWriter: direct_io = true is not yet implemented; \
falling back to buffered I/O (O_DIRECT omitted). \
This will be fixed in a future release once aligned-buffer \
allocation is integrated."
);
}
if config.ring_size == 0 || !config.ring_size.is_power_of_two() {
return Err(AmateRSError::ConfigError(ErrorContext::new(format!(
"UringWalConfig.ring_size must be a non-zero power of two, got {}",
config.ring_size
))));
}
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<UringControl>();
let thread_path = path.clone();
let thread_config = config.clone();
let join_handle = thread::Builder::new()
.name("amaters-uring-wal".to_owned())
.spawn(move || {
run_uring_thread(thread_path, thread_config, rx);
})
.map_err(|e| {
AmateRSError::IoError(ErrorContext::new(format!(
"Failed to spawn io_uring WAL thread: {}",
e
)))
})?;
Ok(Self {
tx: Arc::new(tx),
thread_handle: Arc::new(Mutex::new(Some(join_handle))),
config,
path,
})
}
pub async fn append(&self, entry: &WalEntry) -> Result<()> {
let payload = serialize_entry(entry);
let (done_tx, done_rx) = tokio::sync::oneshot::channel();
self.tx
.send(UringControl::Write(UringWriteRequest { payload, done_tx }))
.map_err(|_| {
AmateRSError::IoError(ErrorContext::new(
"io_uring WAL thread has shut down; cannot append entry",
))
})?;
done_rx
.await
.map_err(|_| {
AmateRSError::IoError(ErrorContext::new(
"io_uring WAL thread dropped the oneshot sender unexpectedly",
))
})?
.map_err(|e| {
AmateRSError::IoError(ErrorContext::new(format!(
"io_uring write_at failed: {}",
e
)))
})
}
pub async fn flush(&self) -> Result<()> {
let (done_tx, done_rx) = tokio::sync::oneshot::channel();
self.tx.send(UringControl::Flush { done_tx }).map_err(|_| {
AmateRSError::IoError(ErrorContext::new(
"io_uring WAL thread has shut down; cannot flush",
))
})?;
done_rx
.await
.map_err(|_| {
AmateRSError::IoError(ErrorContext::new(
"io_uring WAL thread dropped the flush oneshot sender",
))
})?
.map_err(|e| {
AmateRSError::IoError(ErrorContext::new(format!(
"io_uring sync_all failed: {}",
e
)))
})
}
pub fn close(self) -> Result<()> {
let _ = self.tx.send(UringControl::Shutdown);
let mut guard = self.thread_handle.lock().expect(
"UringWalWriter: thread_handle mutex was poisoned; cannot join io_uring thread",
);
if let Some(handle) = guard.take() {
handle.join().map_err(|_| {
AmateRSError::IoError(ErrorContext::new(
"io_uring WAL thread panicked during join",
))
})?;
}
Ok(())
}
pub fn path(&self) -> &Path {
&self.path
}
pub fn config(&self) -> &UringWalConfig {
&self.config
}
}
/// Encodes `entry` as a length-prefixed binary record.
///
/// Layout (all integers little-endian):
///
/// | field          | size            |
/// |----------------|-----------------|
/// | body length    | 4 bytes (`u32`) |
/// | magic          | 4 bytes (`0x57414C`, the ASCII codes of `W`, `A`, `L`) |
/// | sequence       | width of `entry.sequence` |
/// | entry type     | 1 byte (`1` = put, `2` = delete) |
/// | key length     | 4 bytes (`u32`) |
/// | key bytes      | key length      |
/// | value length   | 4 bytes (`u32`, `0` when there is no value) |
/// | value bytes    | value length    |
/// | checksum       | width of `entry.checksum` |
///
/// The body length counts everything after the prefix itself.
fn serialize_entry(entry: &WalEntry) -> Vec<u8> {
    const LEN_PREFIX: usize = 4;

    // Reserve room for the length prefix up front and patch it in once the
    // body size is known, so the record is assembled in a single buffer.
    let mut record = Vec::with_capacity(LEN_PREFIX + 64);
    record.extend_from_slice(&[0u8; LEN_PREFIX]);

    record.extend_from_slice(&0x57414Cu32.to_le_bytes());
    record.extend_from_slice(&entry.sequence.to_le_bytes());
    record.push(match entry.entry_type {
        WalEntryType::Put => 1u8,
        WalEntryType::Delete => 2u8,
    });

    let key_len = entry.key.len() as u32;
    record.extend_from_slice(&key_len.to_le_bytes());
    record.extend_from_slice(entry.key.as_bytes());

    match entry.value {
        Some(ref blob) => {
            let blob_len = blob.len() as u32;
            record.extend_from_slice(&blob_len.to_le_bytes());
            record.extend_from_slice(blob.as_bytes());
        }
        None => record.extend_from_slice(&0u32.to_le_bytes()),
    }

    record.extend_from_slice(&entry.checksum.to_le_bytes());

    let body_len = (record.len() - LEN_PREFIX) as u32;
    record[..LEN_PREFIX].copy_from_slice(&body_len.to_le_bytes());
    record
}
/// Body of the dedicated writer thread: starts a `tokio_uring` runtime sized
/// by `config.ring_size` and drives the WAL event loop on it until it ends.
///
/// An error returned by the event loop is logged, not propagated; this
/// function has no return value.
fn run_uring_thread(
    path: PathBuf,
    config: UringWalConfig,
    rx: tokio::sync::mpsc::UnboundedReceiver<UringControl>,
) {
    // Read the ring size before `config` is moved into the future.
    let ring_entries = config.ring_size;
    let event_loop = async move {
        match uring_event_loop(path, config, rx).await {
            Ok(()) => {}
            Err(e) => {
                tracing::error!("io_uring WAL thread exited with error: {}", e);
            }
        }
    };
    tokio_uring::builder().entries(ring_entries).start(event_loop);
}
/// Event loop of the writer thread.
///
/// Opens (creating if necessary) the WAL file at `path` in append mode, then
/// services control messages from `rx` until a `Shutdown` message arrives or
/// every sender has been dropped:
///
/// * `Write` — starts a batch and greedily drains further queued messages
///   with `try_recv` while the accumulated payload size is below
///   `config.batch_size`, then submits the batch.
/// * `Flush` — calls `sync_all` and reports the result. When a flush shows up
///   while a batch is being collected, the batch is written first so the sync
///   covers it.
/// * `Shutdown` — writes any batch collected so far and leaves the loop. No
///   sync is performed on shutdown.
///
/// Every exit path ends with an explicit `close` of the file, so the file
/// descriptor is released before this function (and therefore the writer
/// thread) finishes.
///
/// NOTE(review): the file is opened in append mode while writes carry
/// explicit offsets; confirm how the kernel treats the offset in that
/// combination.
///
/// # Errors
///
/// Returns the I/O error from opening or closing the file. Individual write
/// and sync failures are reported to the requester, not returned here.
async fn uring_event_loop(
    path: PathBuf,
    config: UringWalConfig,
    mut rx: tokio::sync::mpsc::UnboundedReceiver<UringControl>,
) -> std::io::Result<()> {
    let file = tokio_uring::fs::OpenOptions::new()
        .create(true)
        .append(true)
        .open(&path)
        .await?;

    // Start writing at the current end of file; a metadata failure falls
    // back to offset 0.
    let mut write_offset: u64 = std::fs::metadata(&path).map(|m| m.len()).unwrap_or(0);
    tracing::debug!(
        path = %path.display(),
        write_offset,
        "io_uring WAL writer started"
    );

    'event_loop: loop {
        let first = match rx.recv().await {
            Some(ctrl) => ctrl,
            None => {
                tracing::debug!("io_uring WAL channel closed; thread exiting");
                break 'event_loop;
            }
        };

        match first {
            UringControl::Shutdown => {
                tracing::debug!("io_uring WAL received Shutdown; thread exiting");
                break 'event_loop;
            }
            UringControl::Flush { done_tx } => {
                let result = file.sync_all().await;
                // The requester may have gone away; nothing to do about that.
                let _ = done_tx.send(result);
            }
            UringControl::Write(first_req) => {
                let mut batch_bytes: usize = first_req.payload.len();
                let mut batch: Vec<UringWriteRequest> = vec![first_req];

                // Opportunistically coalesce whatever is already queued,
                // without waiting for more to arrive.
                while batch_bytes < config.batch_size {
                    match rx.try_recv() {
                        Ok(UringControl::Write(req)) => {
                            batch_bytes += req.payload.len();
                            batch.push(req);
                        }
                        Ok(UringControl::Flush { done_tx }) => {
                            // Writes queued ahead of the flush must reach the
                            // file before the sync.
                            write_offset = submit_batch(&file, batch, write_offset).await;
                            let result = file.sync_all().await;
                            let _ = done_tx.send(result);
                            batch = Vec::new();
                            break;
                        }
                        Ok(UringControl::Shutdown) => {
                            let _ = submit_batch(&file, batch, write_offset).await;
                            tracing::debug!("io_uring WAL received Shutdown mid-batch; exiting");
                            break 'event_loop;
                        }
                        Err(tokio::sync::mpsc::error::TryRecvError::Empty) => {
                            break;
                        }
                        Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
                            let _ = submit_batch(&file, batch, write_offset).await;
                            tracing::debug!("io_uring WAL channel disconnected mid-batch; exiting");
                            break 'event_loop;
                        }
                    }
                }

                if !batch.is_empty() {
                    write_offset = submit_batch(&file, batch, write_offset).await;
                }
            }
        }
    }

    // Close explicitly: dropping a tokio-uring file only schedules the close
    // in the background, with no guarantee about when it completes.
    file.close().await?;
    tracing::debug!(path = %path.display(), "io_uring WAL file closed");
    Ok(())
}
async fn submit_batch(
file: &tokio_uring::fs::File,
batch: Vec<UringWriteRequest>,
mut offset: u64,
) -> u64 {
for req in batch {
let payload_len = req.payload.len() as u64;
let buf = req.payload;
let result = write_all_at(file, buf, offset).await;
match result {
Ok(bytes_written) => {
offset += bytes_written as u64;
let _ = req.done_tx.send(Ok(()));
}
Err(e) => {
let _ = req.done_tx.send(Err(e));
offset += payload_len;
}
}
}
offset
}
async fn write_all_at(
file: &tokio_uring::fs::File,
mut buf: Vec<u8>,
mut offset: u64,
) -> std::io::Result<usize> {
let total = buf.len();
let mut written = 0usize;
while written < total {
let slice = buf[written..].to_vec();
let (result, _returned_buf) = file.write_at(slice, offset + written as u64).submit().await;
match result {
Ok(0) => {
return Err(std::io::Error::new(
std::io::ErrorKind::WriteZero,
"io_uring write_at returned 0 bytes written",
));
}
Ok(n) => {
written += n;
}
Err(e) => {
return Err(e);
}
}
}
Ok(written)
}
#[cfg(all(test, target_os = "linux", feature = "io-uring"))]
mod tests {
    //! Round-trip tests: records written through `UringWalWriter` are read
    //! back with `WalReader` and compared with what was appended.

    use super::*;
    use crate::storage::wal::WalReader;
    use crate::types::{CipherBlob, Key};
    use tempfile::TempDir;

    /// Builds a put entry whose key and value are derived from `key_suffix`.
    fn make_put_entry(seq: u64, key_suffix: &str) -> WalEntry {
        let key = Key::from_str(&format!("test_key_{}", key_suffix));
        let value = CipherBlob::new(format!("test_value_{}", key_suffix).into_bytes());
        WalEntry::put(seq, key, value)
    }

    /// Builds a delete entry whose key is derived from `key_suffix`.
    fn make_delete_entry(seq: u64, key_suffix: &str) -> WalEntry {
        let key = Key::from_str(&format!("del_key_{}", key_suffix));
        WalEntry::delete(seq, key)
    }

    /// Reads every entry from the WAL file at `path`, panicking on any
    /// reader error.
    fn read_all_entries(path: &Path) -> Vec<WalEntry> {
        let mut reader = WalReader::open(path).expect("WalReader::open failed");
        let mut entries = Vec::new();
        loop {
            match reader.read_entry() {
                Ok(Some(e)) => entries.push(e),
                Ok(None) => break,
                Err(e) => panic!("WalReader::read_entry error: {}", e),
            }
        }
        entries
    }

    /// Value bytes of an entry, with "no value" represented as empty so the
    /// comparison does not depend on how a reader decodes a zero-length
    /// value.
    fn value_bytes(entry: &WalEntry) -> Vec<u8> {
        entry
            .value
            .as_ref()
            .map(|v| v.as_bytes().to_vec())
            .unwrap_or_default()
    }

    #[test]
    fn test_uring_wal_open_and_close() {
        let tmp = TempDir::new().expect("TempDir::new failed");
        let wal_path = tmp.path().join("test_open_close.wal");
        let writer =
            UringWalWriter::open(&wal_path, UringWalConfig::default()).expect("open failed");
        assert!(
            writer.path() == wal_path,
            "path() should return the path passed to open()"
        );
        writer.close().expect("close failed");
        assert!(
            wal_path.exists(),
            "WAL file should exist after open+close: {}",
            wal_path.display()
        );
    }

    #[test]
    fn test_uring_wal_append_and_read_back() {
        let tmp = TempDir::new().expect("TempDir::new failed");
        let wal_path = tmp.path().join("test_append.wal");
        let rt = tokio::runtime::Runtime::new().expect("tokio Runtime::new failed");
        let writer =
            UringWalWriter::open(&wal_path, UringWalConfig::default()).expect("open failed");
        let entries = vec![
            make_put_entry(1, "alpha"),
            make_put_entry(2, "beta"),
            make_delete_entry(3, "gamma"),
        ];
        rt.block_on(async {
            for e in &entries {
                writer.append(e).await.expect("append failed");
            }
            writer.flush().await.expect("flush failed");
        });
        writer.close().expect("close failed");

        let recovered = read_all_entries(&wal_path);
        assert_eq!(
            recovered.len(),
            entries.len(),
            "recovered entry count mismatch"
        );
        for (original, recovered_entry) in entries.iter().zip(recovered.iter()) {
            assert_eq!(
                original.sequence, recovered_entry.sequence,
                "sequence mismatch"
            );
            assert_eq!(
                original.entry_type, recovered_entry.entry_type,
                "entry_type mismatch"
            );
            assert_eq!(
                original.key.as_bytes(),
                recovered_entry.key.as_bytes(),
                "key mismatch"
            );
            assert_eq!(
                value_bytes(original),
                value_bytes(recovered_entry),
                "value mismatch"
            );
            assert_eq!(
                original.checksum, recovered_entry.checksum,
                "checksum mismatch"
            );
        }
    }

    #[test]
    fn test_uring_wal_flush() {
        let tmp = TempDir::new().expect("TempDir::new failed");
        let wal_path = tmp.path().join("test_flush.wal");
        let rt = tokio::runtime::Runtime::new().expect("tokio Runtime::new failed");
        let writer = UringWalWriter::open(&wal_path, UringWalConfig::default())
            .expect("UringWalWriter::open failed");
        let entry = make_put_entry(42, "flush_test");
        rt.block_on(async {
            writer.append(&entry).await.expect("append failed");
            writer.flush().await.expect("flush failed");
        });
        writer.close().expect("close failed");

        let recovered = read_all_entries(&wal_path);
        assert_eq!(recovered.len(), 1, "expected exactly 1 recovered entry");
        assert_eq!(recovered[0].sequence, 42);
        assert_eq!(
            recovered[0].key.as_bytes(),
            b"test_key_flush_test",
            "key bytes mismatch"
        );
    }

    #[test]
    fn test_uring_wal_batch_writes() {
        let tmp = TempDir::new().expect("TempDir::new failed");
        let wal_path = tmp.path().join("test_batch.wal");
        let rt = tokio::runtime::Runtime::new().expect("tokio Runtime::new failed");
        let writer = Arc::new(
            UringWalWriter::open(&wal_path, UringWalConfig::default())
                .expect("UringWalWriter::open failed"),
        );
        const NUM_ENTRIES: u64 = 100;
        let entries: Vec<WalEntry> = (0..NUM_ENTRIES)
            .map(|i| make_put_entry(i, &i.to_string()))
            .collect();

        // Issue all appends concurrently so the writer thread has a chance
        // to coalesce them into batches.
        rt.block_on(async {
            let futs: Vec<_> = entries
                .iter()
                .map(|e| {
                    let w = Arc::clone(&writer);
                    async move { w.append(e).await }
                })
                .collect();
            let results = futures::future::join_all(futs).await;
            for (i, res) in results.into_iter().enumerate() {
                res.unwrap_or_else(|err| panic!("append[{}] failed: {}", i, err));
            }
            writer.flush().await.expect("flush failed");
        });

        let writer =
            Arc::try_unwrap(writer).expect("Arc should have only one reference at this point");
        writer.close().expect("close failed");

        let recovered = read_all_entries(&wal_path);
        assert_eq!(
            recovered.len(),
            NUM_ENTRIES as usize,
            "expected {} recovered entries, got {}",
            NUM_ENTRIES,
            recovered.len()
        );
        // Concurrent appends may land in any order; compare as a set.
        let mut seen_seqs: Vec<u64> = recovered.iter().map(|e| e.sequence).collect();
        seen_seqs.sort_unstable();
        let expected: Vec<u64> = (0..NUM_ENTRIES).collect();
        assert_eq!(seen_seqs, expected, "sequence numbers mismatch");
    }
}