use crate::store_utils::{
DEFAULT_TIMEOUT, delete_with_timeout, get_with_timeout, list_with_timeout, put_with_timeout,
};
use anyhow::Result;
use metrics;
use object_store::ObjectStore;
use object_store::path::Path;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};
use tracing::{debug, info, instrument, warn};
use uni_common::Properties;
use uni_common::core::id::{Eid, Vid};
use uni_common::sync::acquire_mutex;
use uuid::Uuid;
/// Extracts the LSN from a WAL segment filename of the form
/// `{lsn:020}_{uuid}.wal` (the format written by `WriteAheadLog::flush`).
///
/// Returns `None` when the path has no filename, the filename is shorter
/// than the 20-digit zero-padded LSN prefix, or the prefix does not parse
/// as a `u64`.
fn parse_lsn_from_filename(path: &Path) -> Option<u64> {
    let filename = path.filename()?;
    // Use `get` instead of `&filename[..20]`: direct slicing panics when
    // byte 20 falls inside a multi-byte UTF-8 character, and this function
    // is fed arbitrary listed object names — callers expect a plain `None`
    // for names that don't match the WAL format. `get` also returns `None`
    // for strings shorter than 20 bytes, subsuming the old length check.
    filename.get(..20)?.parse::<u64>().ok()
}
/// A single logged graph change, serialized as JSON inside a [`WalSegment`].
///
/// Fields marked `#[serde(default)]` were added after segments already
/// existed in the wild; the defaults keep old segments deserializable
/// (see `test_wal_backward_compatibility_labels`).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum Mutation {
    /// Insert an edge `src_vid -> dst_vid`.
    InsertEdge {
        src_vid: Vid,
        dst_vid: Vid,
        edge_type: u32,
        eid: Eid,
        version: u64,
        properties: Properties,
        /// Human-readable edge type name; `None` in segments written
        /// before this field existed.
        #[serde(default)]
        edge_type_name: Option<String>,
    },
    /// Remove an existing edge.
    DeleteEdge {
        eid: Eid,
        src_vid: Vid,
        dst_vid: Vid,
        edge_type: u32,
        version: u64,
    },
    /// Insert a vertex with its properties and labels.
    InsertVertex {
        vid: Vid,
        properties: Properties,
        /// Empty for segments written before labels were logged.
        #[serde(default)]
        labels: Vec<String>,
    },
    /// Remove a vertex.
    DeleteVertex {
        vid: Vid,
        /// Empty for segments written before labels were logged.
        #[serde(default)]
        labels: Vec<String>,
    },
}
/// One flushed batch of mutations, stored as a single JSON object named
/// `{lsn:020}_{uuid}.wal` under the WAL prefix.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WalSegment {
    /// Log sequence number of this segment; strictly increasing across
    /// segments (gaps are possible after failed flushes).
    pub lsn: u64,
    /// Mutations in append order.
    pub mutations: Vec<Mutation>,
}
/// Write-ahead log backed by an [`ObjectStore`].
///
/// Mutations are buffered in memory via `append` and persisted in batches
/// by `flush`; recovery walks the stored segments via `replay_since`.
pub struct WriteAheadLog {
    /// Backing object store holding the segment files.
    store: Arc<dyn ObjectStore>,
    /// Key prefix ("directory") under which segment objects are written.
    prefix: Path,
    /// In-memory buffer and LSN counters, guarded by a mutex.
    state: Mutex<WalState>,
}
/// Mutable WAL state protected by `WriteAheadLog::state`.
struct WalState {
    /// Mutations appended since the last successful flush.
    buffer: Vec<Mutation>,
    /// LSN the next flushed segment will receive.
    next_lsn: u64,
    /// LSN of the most recently persisted segment (0 before any flush).
    flushed_lsn: u64,
}
impl WriteAheadLog {
/// Creates a WAL rooted at `prefix` inside `store`.
///
/// LSNs start at 1; call `initialize` afterwards to resume the counters
/// from any segments already present in the store.
pub fn new(store: Arc<dyn ObjectStore>, prefix: Path) -> Self {
    let fresh = WalState {
        buffer: Vec::new(),
        next_lsn: 1,
        flushed_lsn: 0,
    };
    Self {
        store,
        prefix,
        state: Mutex::new(fresh),
    }
}
/// Scans existing segments and resumes the LSN counters after them.
///
/// Returns the highest LSN found (0 when the store holds no segments).
pub async fn initialize(&self) -> Result<u64> {
    let max_lsn = self.find_max_lsn().await?;
    let mut state = acquire_mutex(&self.state, "wal_state")?;
    state.next_lsn = max_lsn + 1;
    state.flushed_lsn = max_lsn;
    drop(state);
    Ok(max_lsn)
}
/// Returns the highest LSN among segments currently under `prefix`,
/// or 0 when there are none.
///
/// Fast path: the LSN is read straight out of the filename. Only files
/// whose names don't match the `{lsn:020}_...` format are downloaded and
/// deserialized to learn their LSN.
async fn find_max_lsn(&self) -> Result<u64> {
    let metas = list_with_timeout(&self.store, Some(&self.prefix), DEFAULT_TIMEOUT).await?;
    let mut max_lsn: u64 = 0;
    for meta in metas {
        if let Some(lsn) = parse_lsn_from_filename(&meta.location) {
            max_lsn = max_lsn.max(lsn);
        } else {
            warn!(
                path = %meta.location,
                "WAL filename doesn't match expected format, downloading segment"
            );
            let get_result =
                get_with_timeout(&self.store, &meta.location, DEFAULT_TIMEOUT).await?;
            let bytes = get_result.bytes().await?;
            // Zero-byte objects (e.g. an interrupted upload) carry no LSN.
            if bytes.is_empty() {
                continue;
            }
            let segment: WalSegment = serde_json::from_slice(&bytes)?;
            max_lsn = max_lsn.max(segment.lsn);
        }
    }
    Ok(max_lsn)
}
/// Buffers `mutation` in memory; durability happens on the next `flush`.
#[instrument(skip(self, mutation), level = "trace")]
pub fn append(&self, mutation: &Mutation) -> Result<()> {
    acquire_mutex(&self.state, "wal_state")?
        .buffer
        .push(mutation.clone());
    metrics::counter!("uni_wal_entries_total").increment(1);
    Ok(())
}
/// Persists all buffered mutations as a single segment object and returns
/// the segment's LSN; when the buffer is empty, returns the last flushed
/// LSN without doing any I/O.
///
/// Failure handling: if serialization or the upload fails, the claimed
/// LSN is abandoned (leaving a gap — `next_lsn` is never rewound, so LSNs
/// stay strictly monotonic) and the batch is spliced back in FRONT of any
/// mutations appended concurrently, preserving overall append order.
#[instrument(skip(self), fields(lsn, mutations_count, size_bytes))]
pub async fn flush(&self) -> Result<u64> {
    let start = std::time::Instant::now();
    // Claim the batch and an LSN under the lock, then do all I/O outside it.
    let (batch, lsn) = {
        let mut state = acquire_mutex(&self.state, "wal_state")?;
        if state.buffer.is_empty() {
            return Ok(state.flushed_lsn);
        }
        let lsn = state.next_lsn;
        state.next_lsn += 1;
        (std::mem::take(&mut state.buffer), lsn)
    };
    tracing::Span::current().record("lsn", lsn);
    tracing::Span::current().record("mutations_count", batch.len());
    let segment = WalSegment {
        lsn,
        // Cloned so `batch` can be restored into the buffer on failure.
        mutations: batch.clone(),
    };
    let json = match serde_json::to_vec(&segment) {
        Ok(j) => j,
        Err(e) => {
            warn!(lsn, error = %e, "Failed to serialize WAL segment, restoring buffer");
            // Put the failed batch back BEFORE anything appended since we
            // released the lock, so replay order matches append order.
            let mut state = acquire_mutex(&self.state, "wal_state")?;
            let new_mutations = std::mem::take(&mut state.buffer);
            state.buffer = batch;
            state.buffer.extend(new_mutations);
            return Err(e.into());
        }
    };
    tracing::Span::current().record("size_bytes", json.len());
    metrics::counter!("uni_wal_bytes_written_total").increment(json.len() as u64);
    // Zero-padded LSN keeps lexicographic listing order equal to numeric
    // order; the UUID suffix avoids name collisions.
    let filename = format!("{:020}_{}.wal", lsn, Uuid::new_v4());
    let path = self.prefix.child(filename);
    if let Err(e) = put_with_timeout(&self.store, &path, json.into(), DEFAULT_TIMEOUT).await {
        warn!(
            lsn,
            error = %e,
            "Failed to flush WAL segment, restoring buffer (LSN gap preserved for monotonicity)"
        );
        // Same front-splice restore as the serialization failure path.
        let mut state = acquire_mutex(&self.state, "wal_state")?;
        let new_mutations = std::mem::take(&mut state.buffer);
        state.buffer = batch;
        state.buffer.extend(new_mutations);
        return Err(e);
    }
    {
        let mut state = acquire_mutex(&self.state, "wal_state")?;
        state.flushed_lsn = lsn;
    }
    let duration = start.elapsed();
    // NOTE(review): "wal_flush_latency_ms" is the only metric here without
    // the "uni_" prefix — confirm whether dashboards rely on this name
    // before unifying it.
    metrics::histogram!("wal_flush_latency_ms").record(duration.as_millis() as f64);
    metrics::histogram!("uni_wal_flush_duration_seconds").record(duration.as_secs_f64());
    if duration.as_millis() > 100 {
        warn!(
            lsn,
            duration_ms = duration.as_millis(),
            "Slow WAL flush detected"
        );
    } else {
        debug!(
            lsn,
            duration_ms = duration.as_millis(),
            "WAL flush completed"
        );
    }
    Ok(lsn)
}
pub fn flushed_lsn(&self) -> Result<u64, uni_common::sync::LockPoisonedError> {
let guard = uni_common::sync::acquire_mutex(&self.state, "wal_state")?;
Ok(guard.flushed_lsn)
}
#[instrument(skip(self), level = "debug")]
pub async fn replay_since(&self, high_water_mark: u64) -> Result<Vec<Mutation>> {
let start = std::time::Instant::now();
debug!(high_water_mark, "Replaying WAL segments");
let metas = list_with_timeout(&self.store, Some(&self.prefix), DEFAULT_TIMEOUT).await?;
let mut mutations = Vec::new();
let mut paths: Vec<_> = metas.into_iter().map(|m| m.location).collect();
paths.sort();
let mut segments_replayed = 0;
for path in paths {
if let Some(lsn) = parse_lsn_from_filename(&path)
&& lsn <= high_water_mark
{
continue; }
let get_result = get_with_timeout(&self.store, &path, DEFAULT_TIMEOUT).await?;
let bytes = get_result.bytes().await?;
if bytes.is_empty() {
continue;
}
let segment: WalSegment = serde_json::from_slice(&bytes)?;
if segment.lsn > high_water_mark {
mutations.extend(segment.mutations);
segments_replayed += 1;
}
}
info!(
segments_replayed,
mutations_count = mutations.len(),
"WAL replay completed"
);
metrics::histogram!("uni_wal_replay_duration_seconds")
.record(start.elapsed().as_secs_f64());
Ok(mutations)
}
/// Convenience wrapper: replays every segment (high-water mark 0).
pub async fn replay(&self) -> Result<Vec<Mutation>> {
    let everything = self.replay_since(0).await?;
    Ok(everything)
}
/// Deletes every segment whose LSN is `<= high_water_mark` (i.e. segments
/// whose mutations have already been captured elsewhere).
///
/// Files whose names don't match the `{lsn:020}_...` format are
/// downloaded to learn their LSN; zero-byte objects are treated as
/// deletable junk.
#[instrument(skip(self), level = "info")]
pub async fn truncate_before(&self, high_water_mark: u64) -> Result<()> {
    info!(high_water_mark, "Truncating WAL segments");
    let metas = list_with_timeout(&self.store, Some(&self.prefix), DEFAULT_TIMEOUT).await?;
    let mut deleted_count = 0;
    for meta in metas {
        let should_delete = if let Some(lsn) = parse_lsn_from_filename(&meta.location) {
            lsn <= high_water_mark
        } else {
            warn!(
                path = %meta.location,
                "WAL filename doesn't match expected format, downloading segment"
            );
            let get_result =
                get_with_timeout(&self.store, &meta.location, DEFAULT_TIMEOUT).await?;
            let bytes = get_result.bytes().await?;
            if bytes.is_empty() {
                // Empty object: nothing replayable, safe to delete.
                true
            } else {
                let segment: WalSegment = serde_json::from_slice(&bytes)?;
                segment.lsn <= high_water_mark
            }
        };
        if should_delete {
            delete_with_timeout(&self.store, &meta.location, DEFAULT_TIMEOUT).await?;
            deleted_count += 1;
        }
    }
    info!(deleted_count, "WAL truncation completed");
    Ok(())
}
/// Reports whether any segment objects currently exist under the prefix.
pub async fn has_segments(&self) -> Result<bool> {
    let metas = list_with_timeout(&self.store, Some(&self.prefix), DEFAULT_TIMEOUT).await?;
    let any = !metas.is_empty();
    Ok(any)
}
/// Deletes every segment under the prefix, regardless of LSN.
pub async fn truncate(&self) -> Result<()> {
    info!("Truncating all WAL segments");
    let metas = list_with_timeout(&self.store, Some(&self.prefix), DEFAULT_TIMEOUT).await?;
    let mut deleted_count = 0;
    for meta in &metas {
        delete_with_timeout(&self.store, &meta.location, DEFAULT_TIMEOUT).await?;
        deleted_count += 1;
    }
    info!(deleted_count, "Full WAL truncation completed");
    Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use object_store::local::LocalFileSystem;
use std::collections::HashMap;
use tempfile::tempdir;
/// End-to-end smoke test: append -> flush -> replay round-trips one vertex
/// insert, and truncate leaves nothing to replay.
#[tokio::test]
async fn test_wal_append_replay() -> Result<()> {
    let dir = tempdir()?;
    let store = Arc::new(LocalFileSystem::new_with_prefix(dir.path())?);
    let wal = WriteAheadLog::new(store, Path::from("wal"));
    let mutation = Mutation::InsertVertex {
        vid: Vid::new(1),
        properties: HashMap::new(),
        labels: vec![],
    };
    wal.append(&mutation)?;
    wal.flush().await?;
    let replayed = wal.replay().await?;
    assert_eq!(replayed.len(), 1);
    match &replayed[0] {
        Mutation::InsertVertex { vid, .. } => {
            assert_eq!(vid.as_u64(), Vid::new(1).as_u64());
        }
        _ => panic!("Wrong mutation type"),
    }
    wal.truncate().await?;
    assert_eq!(wal.replay().await?.len(), 0);
    Ok(())
}
/// Three consecutive single-mutation flushes must yield strictly
/// increasing, gap-free LSNs.
#[tokio::test]
async fn test_lsn_monotonicity() -> Result<()> {
    let dir = tempdir()?;
    let store = Arc::new(LocalFileSystem::new_with_prefix(dir.path())?);
    let wal = WriteAheadLog::new(store, Path::from("wal"));
    let vertex = |id: u64| Mutation::InsertVertex {
        vid: Vid::new(id),
        properties: HashMap::new(),
        labels: vec![],
    };
    wal.append(&vertex(1))?;
    let lsn1 = wal.flush().await?;
    wal.append(&vertex(2))?;
    let lsn2 = wal.flush().await?;
    wal.append(&vertex(3))?;
    let lsn3 = wal.flush().await?;
    assert!(lsn2 > lsn1, "LSN2 ({}) should be > LSN1 ({})", lsn2, lsn1);
    assert!(lsn3 > lsn2, "LSN3 ({}) should be > LSN2 ({})", lsn3, lsn2);
    assert_eq!(lsn2, lsn1 + 1);
    assert_eq!(lsn3, lsn2 + 1);
    Ok(())
}
/// Table-driven coverage of the filename parser: valid zero-padded names,
/// too-short names, non-numeric prefixes, and an empty path.
#[test]
fn test_parse_lsn_from_filename() {
    let cases: &[(&str, Option<u64>)] = &[
        ("00000000000000000042_a1b2c3d4.wal", Some(42)),
        ("00000000000000001234_e5f6a7b8.wal", Some(1234)),
        ("00000000000000000001_xyz.wal", Some(1)),
        ("12345678901234567890_uuid.wal", Some(12345678901234567890)),
        ("invalid.wal", None),
        ("123.wal", None),
        ("abcdefghijklmnopqrst_uuid.wal", None),
        ("00000000000000000100.wal", Some(100)),
        ("", None),
    ];
    for (name, expected) in cases {
        assert_eq!(parse_lsn_from_filename(&Path::from(*name)), *expected);
    }
}
/// With 100 segments on disk, find_max_lsn must answer from filenames
/// alone (no segment downloads) and therefore stay fast.
#[tokio::test]
async fn test_find_max_lsn_scalability() -> Result<()> {
    let dir = tempdir()?;
    let store = Arc::new(LocalFileSystem::new_with_prefix(dir.path())?);
    let wal = WriteAheadLog::new(store, Path::from("wal"));
    for id in 1..=100 {
        wal.append(&Mutation::InsertVertex {
            vid: Vid::new(id),
            properties: HashMap::new(),
            labels: vec![],
        })?;
        wal.flush().await?;
    }
    let started = std::time::Instant::now();
    let max_lsn = wal.find_max_lsn().await?;
    let took = started.elapsed();
    assert_eq!(max_lsn, 100, "Max LSN should be 100");
    assert!(
        took.as_millis() < 1000,
        "find_max_lsn took {}ms, expected < 1000ms (filename parsing should be fast)",
        took.as_millis()
    );
    Ok(())
}
/// Verifies LSNs increment by exactly one per successful flush, and that
/// a multi-mutation batch (two appends, one flush) shares a single LSN.
///
/// NOTE(review): despite the name, nothing here actually injects a flush
/// failure — every put against the local store is expected to succeed, so
/// the restore-on-failure path in `flush` is not exercised. Doing that
/// would need a fault-injecting ObjectStore; confirm intent or rename.
#[tokio::test]
async fn test_lsn_gaps_preserved_on_flush_failure() -> Result<()> {
    let dir = tempdir()?;
    let store = Arc::new(LocalFileSystem::new_with_prefix(dir.path())?);
    let prefix = Path::from("wal");
    let wal = WriteAheadLog::new(store.clone(), prefix.clone());
    wal.append(&Mutation::InsertVertex {
        vid: Vid::new(1),
        properties: HashMap::new(),
        labels: vec![],
    })?;
    let lsn1 = wal.flush().await?;
    assert_eq!(lsn1, 1);
    wal.append(&Mutation::InsertVertex {
        vid: Vid::new(2),
        properties: HashMap::new(),
        labels: vec![],
    })?;
    let lsn2 = wal.flush().await?;
    assert_eq!(lsn2, 2);
    wal.append(&Mutation::InsertVertex {
        vid: Vid::new(3),
        properties: HashMap::new(),
        labels: vec![],
    })?;
    wal.append(&Mutation::InsertVertex {
        vid: Vid::new(4),
        properties: HashMap::new(),
        labels: vec![],
    })?;
    // Two buffered mutations flushed together -> one segment, LSN 3.
    let lsn4 = wal.flush().await?;
    assert_eq!(lsn4, 3, "LSN should increment monotonically");
    let mutations = wal.replay().await?;
    assert_eq!(mutations.len(), 4, "All 4 mutations should be replayed");
    Ok(())
}
/// Fifty sequential flushes must never hand out the same LSN twice.
#[tokio::test]
async fn test_lsn_watermark_no_reuse() -> Result<()> {
    let dir = tempdir()?;
    let store = Arc::new(LocalFileSystem::new_with_prefix(dir.path())?);
    let wal = WriteAheadLog::new(store, Path::from("wal"));
    let mut seen_lsns = std::collections::HashSet::new();
    for i in 1..=50 {
        wal.append(&Mutation::InsertVertex {
            vid: Vid::new(i),
            properties: HashMap::new(),
            labels: vec![],
        })?;
        let lsn = wal.flush().await?;
        // HashSet::insert returns false when the value was already present.
        assert!(
            seen_lsns.insert(lsn),
            "LSN {} was reused! This violates monotonicity.",
            lsn
        );
        assert_eq!(lsn, i, "LSN should be {}, got {}", i, lsn);
    }
    Ok(())
}
/// Truncating 100 one-mutation segments at watermark 50 must delete by
/// filename (fast) and leave exactly the upper half replayable.
#[tokio::test]
async fn test_truncate_scalability() -> Result<()> {
    let dir = tempdir()?;
    let store = Arc::new(LocalFileSystem::new_with_prefix(dir.path())?);
    let wal = WriteAheadLog::new(store, Path::from("wal"));
    for id in 1..=100 {
        wal.append(&Mutation::InsertVertex {
            vid: Vid::new(id),
            properties: HashMap::new(),
            labels: vec![],
        })?;
        wal.flush().await?;
    }
    let started = std::time::Instant::now();
    wal.truncate_before(50).await?;
    let took = started.elapsed();
    let remaining = wal.replay().await?;
    assert_eq!(
        remaining.len(),
        50,
        "Should have 50 mutations remaining (51-100)"
    );
    assert!(
        took.as_millis() < 1000,
        "truncate_before took {}ms, expected < 1000ms (filename parsing should be fast)",
        took.as_millis()
    );
    Ok(())
}
/// replay_since must skip segments at or below the watermark by filename,
/// downloading only the segments it actually replays.
#[tokio::test]
async fn test_replay_since_skips_old_segments() -> Result<()> {
    let dir = tempdir()?;
    let store = Arc::new(LocalFileSystem::new_with_prefix(dir.path())?);
    let wal = WriteAheadLog::new(store, Path::from("wal"));
    for id in 1..=100 {
        wal.append(&Mutation::InsertVertex {
            vid: Vid::new(id),
            properties: HashMap::new(),
            labels: vec![],
        })?;
        wal.flush().await?;
    }
    let started = std::time::Instant::now();
    let replayed = wal.replay_since(90).await?;
    let took = started.elapsed();
    assert_eq!(replayed.len(), 10, "Should replay only LSNs 91-100");
    assert!(
        took.as_millis() < 500,
        "replay_since took {}ms, expected < 500ms (should skip by filename)",
        took.as_millis()
    );
    Ok(())
}
/// Labels attached to an InsertVertex must survive a flush/replay cycle.
#[tokio::test]
async fn test_wal_replay_preserves_vertex_labels() -> Result<()> {
    let dir = tempdir()?;
    let store = Arc::new(LocalFileSystem::new_with_prefix(dir.path())?);
    let wal = Arc::new(WriteAheadLog::new(store, Path::from("wal")));
    let mut props = HashMap::new();
    props.insert(
        "name".to_string(),
        uni_common::Value::String("Alice".to_string()),
    );
    wal.append(&Mutation::InsertVertex {
        vid: Vid::new(42),
        properties: props,
        labels: vec!["Person".to_string(), "User".to_string()],
    })?;
    wal.flush().await?;
    let mutations = wal.replay().await?;
    assert_eq!(mutations.len(), 1);
    match &mutations[0] {
        Mutation::InsertVertex { vid, labels, .. } => {
            assert_eq!(vid.as_u64(), 42);
            assert_eq!(labels.len(), 2);
            assert!(labels.contains(&"Person".to_string()));
            assert!(labels.contains(&"User".to_string()));
        }
        _ => panic!("Expected InsertVertex mutation"),
    }
    Ok(())
}
/// Labels on a DeleteVertex must also round-trip through flush/replay.
#[tokio::test]
async fn test_wal_replay_preserves_delete_vertex_labels() -> Result<()> {
    let dir = tempdir()?;
    let store = Arc::new(LocalFileSystem::new_with_prefix(dir.path())?);
    let wal = Arc::new(WriteAheadLog::new(store, Path::from("wal")));
    wal.append(&Mutation::DeleteVertex {
        vid: Vid::new(99),
        labels: vec!["Person".to_string(), "Admin".to_string()],
    })?;
    wal.flush().await?;
    let mutations = wal.replay().await?;
    assert_eq!(mutations.len(), 1);
    match &mutations[0] {
        Mutation::DeleteVertex { vid, labels } => {
            assert_eq!(vid.as_u64(), 99);
            assert_eq!(labels.len(), 2);
            assert!(labels.contains(&"Person".to_string()));
            assert!(labels.contains(&"Admin".to_string()));
        }
        _ => panic!("Expected DeleteVertex mutation"),
    }
    Ok(())
}
/// The optional edge_type_name field must round-trip through flush/replay.
#[tokio::test]
async fn test_wal_replay_preserves_edge_type_name() -> Result<()> {
    let dir = tempdir()?;
    let store = Arc::new(LocalFileSystem::new_with_prefix(dir.path())?);
    let wal = Arc::new(WriteAheadLog::new(store, Path::from("wal")));
    let mut props = HashMap::new();
    props.insert("since".to_string(), uni_common::Value::Int(2020));
    wal.append(&Mutation::InsertEdge {
        src_vid: Vid::new(1),
        dst_vid: Vid::new(2),
        edge_type: 100,
        eid: Eid::new(500),
        version: 1,
        properties: props,
        edge_type_name: Some("KNOWS".to_string()),
    })?;
    wal.flush().await?;
    let mutations = wal.replay().await?;
    assert_eq!(mutations.len(), 1);
    match &mutations[0] {
        Mutation::InsertEdge {
            eid,
            edge_type_name,
            ..
        } => {
            assert_eq!(eid.as_u64(), 500);
            assert_eq!(edge_type_name.as_deref(), Some("KNOWS"));
        }
        _ => panic!("Expected InsertEdge mutation"),
    }
    Ok(())
}
/// A segment written before the `labels` field existed (the key is absent
/// from the JSON) must still deserialize, with labels defaulting to empty.
#[tokio::test]
async fn test_wal_backward_compatibility_labels() -> Result<()> {
    let dir = tempdir()?;
    let store = Arc::new(LocalFileSystem::new_with_prefix(dir.path())?);
    let prefix = Path::from("wal");
    let old_format_json = r#"{
"lsn": 1,
"mutations": [
{
"InsertVertex": {
"vid": 123,
"properties": {}
}
}
]
}"#;
    let path = prefix.child("00000000000000000001_test.wal");
    store.put(&path, old_format_json.into()).await?;
    let wal = WriteAheadLog::new(store, prefix);
    let mutations = wal.replay().await?;
    assert_eq!(mutations.len(), 1);
    match &mutations[0] {
        Mutation::InsertVertex { vid, labels, .. } => {
            assert_eq!(vid.as_u64(), 123);
            assert_eq!(
                labels.len(),
                0,
                "Old format should deserialize with empty labels"
            );
        }
        _ => panic!("Expected InsertVertex mutation"),
    }
    Ok(())
}
}