use std::collections::VecDeque;
use std::path::PathBuf;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::io::AsyncWriteExt;
use tokio::sync::Mutex;
use tracing::{info, warn};
use uuid::Uuid;
use crate::api::cache::invalidation::gateway_cache_entry_matches_table_invalidation;
use crate::drivers::postgresql::sqlx_driver::{
PostgresClientRegistry, delete_rows, insert_row, update_rows,
};
use crate::parser::query_builder::condition::Condition;
/// Tunables for the deferred-write subsystem.
///
/// None of these fields are read inside this module; the notes below describe
/// what the names and the neighbouring APIs suggest and should be confirmed
/// against the code that consumes the config.
#[derive(Debug, Clone)]
pub struct DeferredWriteConfig {
    /// Master switch; presumably gates whether writes are deferred at all.
    pub enabled: bool,
    /// Flush period in milliseconds (presumably forwarded to `spawn_flush_loop`).
    pub batch_window_ms: u64,
    /// Buffer size threshold (presumably forwarded to `WriteBuffer::new`).
    pub batch_max_size: usize,
    /// Whether a write-ahead log should be kept (presumably gates `WalManager` creation).
    pub wal_enabled: bool,
    /// Directory for the write-ahead log (presumably passed to `WalManager::new`).
    pub wal_dir: String,
    /// NOTE(review): not consulted by `flush_pending` in this file; confirm where it is honoured.
    pub skip_cache_invalidation: bool,
}

impl Default for DeferredWriteConfig {
    /// Deferred writes off, a one-second window, batches of 100, and the WAL
    /// enabled under `./data/wal`.
    fn default() -> Self {
        let wal_dir = String::from("./data/wal");
        Self {
            enabled: false,
            skip_cache_invalidation: false,
            wal_enabled: true,
            wal_dir,
            batch_window_ms: 1_000,
            batch_max_size: 100,
        }
    }
}
/// Serializable form of a single `column = value` filter.
///
/// Stored inside [`DeferredOperation::Update`] and [`DeferredOperation::Delete`]
/// so the filter can be serialized together with its entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredCondition {
    /// Name of the column being compared.
    pub column: String,
    /// JSON value the column is compared against.
    pub value: Value,
}

impl StoredCondition {
    /// Builds the query-builder equality [`Condition`] for this pair.
    ///
    /// Both fields are cloned; `self` is left untouched.
    pub fn to_condition(&self) -> Condition {
        let Self { column, value } = self;
        Condition::eq(column.clone(), value.clone())
    }
}
/// The database mutation carried by a deferred entry.
///
/// Serialized by serde as an internally tagged enum: the variant is written
/// to an `op` field in snake_case (`insert`, `update`, `delete`) next to the
/// variant's own fields. Because this type is embedded in `DeferredEntry`,
/// which `WalManager` writes to disk, renaming variants or fields changes the
/// WAL line format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "op", rename_all = "snake_case")]
pub enum DeferredOperation {
/// Insert one row; `payload` is handed to `insert_row` as-is.
Insert {
payload: Value,
},
/// Update the rows matching every condition, applying `set_payload`
/// through `update_rows`.
Update {
conditions: Vec<StoredCondition>,
set_payload: Value,
},
/// Delete the rows matching every condition through `delete_rows`.
Delete {
conditions: Vec<StoredCondition>,
},
}
/// One queued write: which client and table it targets, what to do, and when
/// it was created.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeferredEntry {
    /// Random UUID (v4) rendered as a string; assigned at construction.
    pub id: String,
    /// Name used to look up the connection pool at flush time.
    pub client_name: String,
    /// Table the operation is applied to.
    pub table_name: String,
    /// The mutation to perform.
    pub operation: DeferredOperation,
    /// Creation time in milliseconds since the Unix epoch (UTC clock).
    pub created_at_unix_ms: i64,
}

impl DeferredEntry {
    /// Shared constructor: stamps a fresh id and the current time onto `operation`.
    fn from_operation(
        client_name: String,
        table_name: String,
        operation: DeferredOperation,
    ) -> Self {
        Self {
            id: Uuid::new_v4().to_string(),
            client_name,
            table_name,
            operation,
            created_at_unix_ms: chrono::Utc::now().timestamp_millis(),
        }
    }

    /// Creates an entry that inserts `payload` into `table_name`.
    pub fn new_insert(client_name: String, table_name: String, payload: Value) -> Self {
        Self::from_operation(
            client_name,
            table_name,
            DeferredOperation::Insert { payload },
        )
    }

    /// Creates an entry that applies `set_payload` to the rows matching `conditions`.
    pub fn new_update(
        client_name: String,
        table_name: String,
        conditions: Vec<StoredCondition>,
        set_payload: Value,
    ) -> Self {
        let operation = DeferredOperation::Update {
            conditions,
            set_payload,
        };
        Self::from_operation(client_name, table_name, operation)
    }

    /// Creates an entry that deletes the rows matching `conditions`.
    pub fn new_delete(
        client_name: String,
        table_name: String,
        conditions: Vec<StoredCondition>,
    ) -> Self {
        Self::from_operation(
            client_name,
            table_name,
            DeferredOperation::Delete { conditions },
        )
    }
}
/// In-memory FIFO of deferred entries, guarded by an async mutex.
pub struct WriteBuffer {
    /// Pending entries, oldest first.
    entries: Mutex<VecDeque<DeferredEntry>>,
    /// Length at which [`WriteBuffer::is_threshold_exceeded`] starts returning
    /// `true`; `0` disables the threshold.
    pub batch_max_size: usize,
}

impl WriteBuffer {
    /// Creates an empty buffer with the given size threshold.
    pub fn new(batch_max_size: usize) -> Self {
        let entries = Mutex::new(VecDeque::new());
        Self {
            entries,
            batch_max_size,
        }
    }

    /// Appends `entry` to the back of the queue.
    pub async fn push(&self, entry: DeferredEntry) {
        let mut queue = self.entries.lock().await;
        queue.push_back(entry);
    }

    /// Removes and returns every queued entry in insertion order.
    pub async fn drain_all(&self) -> Vec<DeferredEntry> {
        let mut queue = self.entries.lock().await;
        let mut drained = Vec::with_capacity(queue.len());
        drained.extend(queue.drain(..));
        drained
    }

    /// Number of entries currently queued.
    pub async fn len(&self) -> usize {
        let queue = self.entries.lock().await;
        queue.len()
    }

    /// Returns `true` once the queue holds at least `batch_max_size` entries.
    ///
    /// Always `false` (without taking the lock) when `batch_max_size` is `0`.
    pub async fn is_threshold_exceeded(&self) -> bool {
        if self.batch_max_size == 0 {
            return false;
        }
        self.len().await >= self.batch_max_size
    }
}
pub struct WalManager {
wal_path: PathBuf,
lock: Mutex<()>,
}
impl WalManager {
pub fn new(wal_dir: &str) -> std::io::Result<Self> {
let dir = PathBuf::from(wal_dir);
std::fs::create_dir_all(&dir)?;
Ok(Self {
wal_path: dir.join("pending.wal"),
lock: Mutex::new(()),
})
}
pub async fn append(&self, entry: &DeferredEntry) -> std::io::Result<()> {
let _lock = self.lock.lock().await;
let line = serde_json::to_string(entry)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
let mut file = tokio::fs::OpenOptions::new()
.create(true)
.append(true)
.open(&self.wal_path)
.await?;
file.write_all(line.as_bytes()).await?;
file.write_all(b"\n").await?;
file.flush().await
}
pub async fn read_pending(&self) -> std::io::Result<Vec<DeferredEntry>> {
let _lock = self.lock.lock().await;
if !self.wal_path.exists() {
return Ok(Vec::new());
}
let content = tokio::fs::read_to_string(&self.wal_path).await?;
let entries = content
.lines()
.filter(|l| !l.trim().is_empty())
.filter_map(|line| serde_json::from_str::<DeferredEntry>(line).ok())
.collect();
Ok(entries)
}
pub async fn clear(&self) -> std::io::Result<()> {
let _lock = self.lock.lock().await;
if self.wal_path.exists() {
tokio::fs::write(&self.wal_path, b"").await?;
}
Ok(())
}
pub fn wal_path(&self) -> &PathBuf {
&self.wal_path
}
}
/// Outcome of one `flush_pending` run.
#[derive(Debug, Default)]
pub struct FlushSummary {
    /// Number of entries drained from the buffer.
    pub total: usize,
    /// Entries whose database operation returned `Ok`.
    pub succeeded: usize,
    /// Entries that had no connection pool or whose operation returned an error.
    pub failed: usize,
    /// Distinct `(client_name, table_name)` pairs with at least one successful write.
    pub invalidated_tables: Vec<(String, String)>,
}
pub async fn flush_pending(
buffer: &WriteBuffer,
wal: Option<&WalManager>,
pg_registry: &PostgresClientRegistry,
cache: &moka::future::Cache<String, serde_json::Value>,
) -> FlushSummary {
let entries = buffer.drain_all().await;
if entries.is_empty() {
return FlushSummary::default();
}
let total = entries.len();
let mut succeeded = 0usize;
let mut failed = 0usize;
let mut dirty: std::collections::HashSet<(String, String)> = std::collections::HashSet::new();
for entry in &entries {
let pool = match pg_registry.get_pool(&entry.client_name) {
Some(p) => p,
None => {
warn!(
entry_id = %entry.id,
client = %entry.client_name,
table = %entry.table_name,
"deferred_write flush: no pool for client, skipping entry"
);
failed += 1;
continue;
}
};
let result: Result<(), String> = match &entry.operation {
DeferredOperation::Insert { payload } => insert_row(&pool, &entry.table_name, payload)
.await
.map(|_| ())
.map_err(|e| format!("{e:?}")),
DeferredOperation::Update {
conditions,
set_payload,
} => {
let conds: Vec<Condition> = conditions
.iter()
.map(StoredCondition::to_condition)
.collect();
update_rows(&pool, &entry.table_name, &conds, set_payload)
.await
.map(|_| ())
.map_err(|e| e.to_string())
}
DeferredOperation::Delete { conditions } => {
let conds: Vec<Condition> = conditions
.iter()
.map(StoredCondition::to_condition)
.collect();
delete_rows(&pool, &entry.table_name, &conds)
.await
.map(|_| ())
.map_err(|e| e.to_string())
}
};
match result {
Ok(()) => {
succeeded += 1;
dirty.insert((entry.client_name.clone(), entry.table_name.clone()));
}
Err(e) => {
warn!(
entry_id = %entry.id,
client = %entry.client_name,
table = %entry.table_name,
error = %e,
"deferred_write flush: entry failed, will be lost (not re-queued)"
);
failed += 1;
}
}
}
let table_name_clone_for_closure = dirty.clone();
let _ = cache
.invalidate_entries_if(move |key, _value| {
table_name_clone_for_closure
.iter()
.any(|(_, table)| gateway_cache_entry_matches_table_invalidation(key, table))
})
.ok();
cache.run_pending_tasks().await;
let invalidated_tables: Vec<(String, String)> = dirty.into_iter().collect();
if let Some(wal_mgr) = wal {
if let Err(e) = wal_mgr.clear().await {
warn!(error = %e, "deferred_write: failed to truncate WAL after flush");
}
}
if succeeded > 0 || failed > 0 {
info!(
succeeded,
failed,
total,
tables = ?invalidated_tables.iter().map(|(_, t)| t).collect::<Vec<_>>(),
"deferred_write flush complete"
);
}
FlushSummary {
total,
succeeded,
failed,
invalidated_tables,
}
}
/// Re-queues every entry found in the WAL into `buffer`.
///
/// The WAL file itself is left unchanged. A read failure is logged and
/// otherwise ignored; an empty WAL is a silent no-op.
pub async fn recover_from_wal(wal: &WalManager, buffer: &WriteBuffer) {
    let pending = match wal.read_pending().await {
        Ok(pending) => pending,
        Err(e) => {
            warn!(
                error = %e,
                "deferred_write: failed to read WAL for recovery, entries may be lost"
            );
            return;
        }
    };

    if pending.is_empty() {
        return;
    }

    let count = pending.len();
    for entry in pending {
        buffer.push(entry).await;
    }
    info!(
        recovered = count,
        wal_path = %wal.wal_path().display(),
        "deferred_write: recovered entries from WAL"
    );
}
/// Spawns a detached Tokio task that calls `flush_pending` on every tick of a
/// `batch_window_ms` interval.
///
/// Does nothing when `batch_window_ms` is `0`. The interval uses
/// `MissedTickBehavior::Skip`. The task loops forever and its join handle is
/// dropped, so it cannot be awaited or cancelled by the caller.
pub fn spawn_flush_loop(
    buffer: Arc<WriteBuffer>,
    wal: Option<Arc<WalManager>>,
    pg_registry: Arc<PostgresClientRegistry>,
    cache: Arc<moka::future::Cache<String, serde_json::Value>>,
    batch_window_ms: u64,
) {
    if batch_window_ms == 0 {
        return;
    }
    let period = std::time::Duration::from_millis(batch_window_ms);

    tokio::spawn(async move {
        let mut ticker = tokio::time::interval(period);
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        loop {
            ticker.tick().await;
            // The summary is only used for logging inside `flush_pending`.
            let _summary = flush_pending(&buffer, wal.as_deref(), &pg_registry, &cache).await;
        }
    });
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Insert entry for the shared test client/table carrying `{ "name": row_name }`.
    fn sample_insert(row_name: &str) -> DeferredEntry {
        DeferredEntry::new_insert(
            "athena_logging".to_string(),
            "bench_table".to_string(),
            serde_json::json!({ "name": row_name }),
        )
    }

    /// Fresh, uniquely named WAL directory under the system temp dir.
    fn unique_wal_dir() -> std::path::PathBuf {
        std::env::temp_dir().join(format!("athena_dw_test_{}", Uuid::new_v4()))
    }

    /// Best-effort removal of the WAL file and its directory.
    fn cleanup(wal: &WalManager, wal_dir: std::path::PathBuf) {
        let _ = std::fs::remove_file(wal.wal_path());
        let _ = std::fs::remove_dir_all(wal_dir);
    }

    #[tokio::test]
    async fn write_buffer_push_len_and_drain_round_trip() {
        let buffer = WriteBuffer::new(10);
        assert_eq!(buffer.len().await, 0);

        buffer.push(sample_insert("row-1")).await;
        assert_eq!(buffer.len().await, 1);

        let drained = buffer.drain_all().await;
        assert_eq!(drained.len(), 1);
        assert_eq!(buffer.len().await, 0);
    }

    #[tokio::test]
    async fn write_buffer_threshold_respects_batch_size() {
        let buffer = WriteBuffer::new(2);
        assert!(!buffer.is_threshold_exceeded().await);

        buffer.push(sample_insert("row-1")).await;
        assert!(!buffer.is_threshold_exceeded().await);

        buffer.push(sample_insert("row-2")).await;
        assert!(buffer.is_threshold_exceeded().await);
    }

    #[tokio::test]
    async fn wal_append_read_pending_and_clear() {
        let wal_dir = unique_wal_dir();
        let wal = WalManager::new(wal_dir.to_string_lossy().as_ref()).expect("wal create");

        let entry = sample_insert("row-1");
        wal.append(&entry).await.expect("wal append");

        let pending = wal.read_pending().await.expect("wal read");
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].client_name, "athena_logging");
        assert_eq!(pending[0].table_name, "bench_table");

        wal.clear().await.expect("wal clear");
        let after_clear = wal.read_pending().await.expect("wal read after clear");
        assert!(after_clear.is_empty());

        cleanup(&wal, wal_dir);
    }

    #[tokio::test]
    async fn recover_from_wal_pushes_entries_back_into_buffer() {
        let wal_dir = unique_wal_dir();
        let wal = WalManager::new(wal_dir.to_string_lossy().as_ref()).expect("wal create");
        let buffer = WriteBuffer::new(10);

        wal.append(&sample_insert("row-1"))
            .await
            .expect("append 1");
        wal.append(&sample_insert("row-2"))
            .await
            .expect("append 2");

        recover_from_wal(&wal, &buffer).await;
        assert_eq!(buffer.len().await, 2);

        cleanup(&wal, wal_dir);
    }
}