use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::mpsc::{self, Receiver, SyncSender, TrySendError};
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::Mutex;
use super::config::IngestionSafetyLevel;
/// One buffered row mutation.
///
/// `Insert` and `Update` carry a payload; `Delete` only names the row.
#[derive(Debug, Clone)]
pub enum WriteOp {
    /// Insert of row `row_id` into `table` with payload `data`.
    Insert {
        table: String,
        row_id: u64,
        data: Vec<u8>,
    },
    /// Update of row `row_id` in `table` with payload `data`.
    Update {
        table: String,
        row_id: u64,
        data: Vec<u8>,
    },
    /// Deletion of row `row_id` from `table`.
    Delete { table: String, row_id: u64 },
}

impl WriteOp {
    /// Returns the name of the table this operation targets.
    pub fn table(&self) -> &str {
        match self {
            Self::Insert { table, .. }
            | Self::Update { table, .. }
            | Self::Delete { table, .. } => table.as_str(),
        }
    }

    /// Returns the id of the row this operation targets.
    pub fn row_id(&self) -> u64 {
        match *self {
            Self::Insert { row_id, .. }
            | Self::Update { row_id, .. }
            | Self::Delete { row_id, .. } => row_id,
        }
    }

    /// Returns the storage key `data:<table>:<row_id>` as UTF-8 bytes.
    pub fn key(&self) -> Vec<u8> {
        let table = self.table();
        let row_id = self.row_id();
        format!("data:{table}:{row_id}").into_bytes()
    }

    /// Returns the row payload, or `None` for a `Delete`.
    pub fn data(&self) -> Option<&[u8]> {
        match self {
            Self::Insert { data, .. } | Self::Update { data, .. } => Some(data.as_slice()),
            Self::Delete { .. } => None,
        }
    }

    /// Approximate size in bytes: the table-name length, 8 bytes for the `u64` row id,
    /// and the payload length (zero for a `Delete`).
    pub fn size(&self) -> usize {
        let payload_len = self.data().map_or(0, |payload| payload.len());
        self.table().len() + std::mem::size_of::<u64>() + payload_len
    }
}
/// Accumulates the writes (and recorded reads) of one transaction before commit.
///
/// Keeps a running total of [`WriteOp::size`] over the buffered operations and
/// remembers whether any write has been added since creation.
#[derive(Debug)]
pub struct TransactionBuffer {
    /// Identifier assigned when the buffer was handed out.
    pub txn_id: u64,
    /// Timestamp captured when the buffer was handed out.
    pub read_timestamp: u64,
    operations: Vec<WriteOp>,
    // Running sum of `WriteOp::size()` over `operations`.
    size: usize,
    // Row ids recorded per table, in call order (duplicates are kept).
    read_set: HashMap<String, Vec<u64>>,
    // True until the first write operation is added.
    read_only: bool,
}

impl TransactionBuffer {
    /// Number of operation slots reserved up front.
    const INITIAL_OPS_CAPACITY: usize = 64;

    /// Creates an empty, read-only buffer.
    pub fn new(txn_id: u64, read_timestamp: u64) -> Self {
        Self {
            txn_id,
            read_timestamp,
            operations: Vec::with_capacity(Self::INITIAL_OPS_CAPACITY),
            size: 0,
            read_set: HashMap::new(),
            read_only: true,
        }
    }

    /// Buffers an insert of `row_id` into `table`.
    pub fn insert(&mut self, table: String, row_id: u64, data: Vec<u8>) {
        self.add_operation(WriteOp::Insert { table, row_id, data });
    }

    /// Buffers an update of `row_id` in `table`.
    pub fn update(&mut self, table: String, row_id: u64, data: Vec<u8>) {
        self.add_operation(WriteOp::Update { table, row_id, data });
    }

    /// Buffers a delete of `row_id` from `table`.
    pub fn delete(&mut self, table: String, row_id: u64) {
        self.add_operation(WriteOp::Delete { table, row_id });
    }

    /// Records that `row_id` of `table` was read; does not affect `is_read_only`.
    pub fn record_read(&mut self, table: &str, row_id: u64) {
        // Look up by `&str` first so the table name is only copied the first time.
        match self.read_set.get_mut(table) {
            Some(rows) => rows.push(row_id),
            None => {
                self.read_set.insert(table.to_owned(), vec![row_id]);
            }
        }
    }

    /// Borrows the buffered operations in insertion order.
    pub fn operations(&self) -> &[WriteOp] {
        &self.operations
    }

    /// Consumes the buffer, yielding its operations in insertion order.
    pub fn take_operations(self) -> Vec<WriteOp> {
        self.operations
    }

    /// Appends `op`, updates the size total, and marks the buffer as having writes.
    pub fn add_operation(&mut self, op: WriteOp) {
        self.size += op.size();
        self.operations.push(op);
        self.read_only = false;
    }

    /// Number of buffered operations (same as [`len`](Self::len)).
    pub fn operation_count(&self) -> usize {
        self.len()
    }

    /// Sum of [`WriteOp::size`] over all buffered operations.
    pub fn size(&self) -> usize {
        self.size
    }

    /// True when no write operation has been added.
    pub fn is_read_only(&self) -> bool {
        self.read_only
    }

    /// True when no operations are buffered.
    pub fn is_empty(&self) -> bool {
        self.operations.is_empty()
    }

    /// Number of buffered operations.
    pub fn len(&self) -> usize {
        self.operations.len()
    }
}
/// The writes of one transaction, queued for the commit worker.
#[derive(Debug)]
pub struct CommitRequest {
    /// Id of the transaction that produced these writes.
    pub txn_id: u64,
    /// The buffered operations, in the order they were added to the transaction.
    pub operations: Vec<WriteOp>,
    /// Channel for delivering the outcome; `None` when the submitter does not wait
    /// for a result.
    pub response: Option<tokio::sync::oneshot::Sender<CommitResult>>,
}
/// Outcome of a commit.
#[derive(Debug, Clone)]
pub enum CommitResult {
    /// The commit succeeded at `commit_timestamp` and covered `ops_count` operations.
    Success { commit_timestamp: u64, ops_count: usize },
    /// The commit did not go through; `error` describes why.
    Failed { error: String },
}
/// Hands out transaction buffers and forwards their writes to a single commit worker
/// through a bounded channel.
pub struct WriteCoordinator {
    // Producer side of the bounded commit queue.
    commit_sender: SyncSender<CommitRequest>,
    // Consumer side, parked here until a worker claims it via `take_receiver`.
    commit_receiver: Mutex<Option<Receiver<CommitRequest>>>,
    safety_level: IngestionSafetyLevel,
    next_txn_id: AtomicU64,
    current_timestamp: AtomicU64,
    running: AtomicBool,
    commits_total: AtomicU64,
    ops_total: AtomicU64,
    bytes_total: AtomicU64,
}

impl WriteCoordinator {
    /// Creates a coordinator whose commit queue holds at most `queue_size` requests.
    pub fn new(safety_level: IngestionSafetyLevel, queue_size: usize) -> Self {
        let (sender, receiver) = mpsc::sync_channel(queue_size);
        Self {
            commit_sender: sender,
            commit_receiver: Mutex::new(Some(receiver)),
            safety_level,
            next_txn_id: AtomicU64::new(1),
            current_timestamp: AtomicU64::new(1),
            running: AtomicBool::new(true),
            commits_total: AtomicU64::new(0),
            ops_total: AtomicU64::new(0),
            bytes_total: AtomicU64::new(0),
        }
    }

    /// Starts a transaction with a fresh id and the current timestamp as its read timestamp.
    pub fn begin_transaction(&self) -> TransactionBuffer {
        let txn_id = self.next_txn_id.fetch_add(1, Ordering::Relaxed);
        let read_timestamp = self.current_timestamp.load(Ordering::Acquire);
        TransactionBuffer::new(txn_id, read_timestamp)
    }

    /// Enqueues the buffer's writes without waiting for them to be committed.
    ///
    /// Read-only buffers return `Ok(())` without touching the queue.
    ///
    /// # Errors
    /// Returns an error if the commit queue is full (backpressure) or its receiving
    /// side has been dropped.
    pub fn submit_commit(&self, buffer: TransactionBuffer) -> Result<(), String> {
        if buffer.is_read_only() {
            return Ok(());
        }
        let request = CommitRequest {
            txn_id: buffer.txn_id,
            operations: buffer.take_operations(),
            response: None,
        };
        match self.commit_sender.try_send(request) {
            Ok(()) => Ok(()),
            Err(TrySendError::Full(_)) => Err("Commit queue full - backpressure".to_string()),
            Err(TrySendError::Disconnected(_)) => Err("Write coordinator shut down".to_string()),
        }
    }

    /// Enqueues the buffer's writes and waits for the commit worker's result.
    ///
    /// Read-only buffers succeed immediately with their read timestamp and zero ops.
    /// While the commit queue is full, this yields to the async executor and retries
    /// rather than blocking the executor thread. It gives up with
    /// [`CommitResult::Failed`] if the coordinator is shut down in the meantime or the
    /// receiving side has been dropped.
    pub async fn commit_and_wait(&self, buffer: TransactionBuffer) -> CommitResult {
        if buffer.is_read_only() {
            return CommitResult::Success {
                commit_timestamp: buffer.read_timestamp,
                ops_count: 0,
            };
        }
        let (tx, rx) = tokio::sync::oneshot::channel();
        let mut request = CommitRequest {
            txn_id: buffer.txn_id,
            operations: buffer.take_operations(),
            response: Some(tx),
        };
        // `SyncSender::send` blocks the calling OS thread while the queue is full, which
        // inside an async fn stalls every task on that executor thread. Poll with
        // `try_send` and yield between attempts instead.
        loop {
            match self.commit_sender.try_send(request) {
                Ok(()) => break,
                Err(TrySendError::Full(rejected)) => {
                    if !self.is_running() {
                        return CommitResult::Failed {
                            error: "Write coordinator shut down".to_string(),
                        };
                    }
                    request = rejected;
                    Self::yield_to_executor().await;
                }
                Err(TrySendError::Disconnected(_)) => {
                    return CommitResult::Failed {
                        error: "Write coordinator shut down".to_string(),
                    };
                }
            }
        }
        match rx.await {
            Ok(result) => result,
            Err(_) => CommitResult::Failed {
                error: "Response channel closed".to_string(),
            },
        }
    }

    /// Returns `Pending` once after scheduling its own wake-up, letting the executor
    /// run other tasks before the caller resumes.
    async fn yield_to_executor() {
        let mut yielded = false;
        std::future::poll_fn(move |cx| {
            if yielded {
                std::task::Poll::Ready(())
            } else {
                yielded = true;
                cx.waker().wake_by_ref();
                std::task::Poll::Pending
            }
        })
        .await;
    }

    /// Hands the queue's receiving side to the caller; `None` once it has been taken.
    pub fn take_receiver(&self) -> Option<Receiver<CommitRequest>> {
        self.commit_receiver.lock().take()
    }

    /// Increments the current timestamp and returns the new value.
    pub fn advance_timestamp(&self) -> u64 {
        self.current_timestamp.fetch_add(1, Ordering::AcqRel) + 1
    }

    /// Returns the current timestamp.
    pub fn current_timestamp(&self) -> u64 {
        self.current_timestamp.load(Ordering::Acquire)
    }

    /// True until [`shutdown`](Self::shutdown) has been called.
    pub fn is_running(&self) -> bool {
        self.running.load(Ordering::Acquire)
    }

    /// Clears the running flag. Does not close the commit queue by itself.
    pub fn shutdown(&self) {
        self.running.store(false, Ordering::Release);
    }

    /// Adds one commit of `ops_count` operations totalling `bytes` to the statistics.
    pub fn record_commit(&self, ops_count: usize, bytes: usize) {
        self.commits_total.fetch_add(1, Ordering::Relaxed);
        self.ops_total.fetch_add(ops_count as u64, Ordering::Relaxed);
        self.bytes_total.fetch_add(bytes as u64, Ordering::Relaxed);
    }

    /// Returns the `(commits, operations, bytes)` totals recorded so far.
    pub fn stats(&self) -> (u64, u64, u64) {
        (
            self.commits_total.load(Ordering::Relaxed),
            self.ops_total.load(Ordering::Relaxed),
            self.bytes_total.load(Ordering::Relaxed),
        )
    }
}
pub struct BatchedCommitWorker {
receiver: Receiver<CommitRequest>,
safety_level: IngestionSafetyLevel,
pending: Vec<CommitRequest>,
pending_bytes: usize,
last_flush: Instant,
coordinator: Arc<WriteCoordinator>,
}
impl BatchedCommitWorker {
pub fn new(coordinator: Arc<WriteCoordinator>) -> Self {
#[allow(clippy::expect_used)]
let receiver = coordinator
.take_receiver()
.expect("Receiver already taken by another worker");
Self {
receiver,
safety_level: coordinator.safety_level.clone(),
pending: Vec::with_capacity(1024),
pending_bytes: 0,
last_flush: Instant::now(),
coordinator,
}
}
fn should_flush(&self) -> bool {
if self.pending.is_empty() {
return false;
}
match &self.safety_level {
IngestionSafetyLevel::Full => true,
IngestionSafetyLevel::Batched { batch_size, batch_timeout_ms } => {
self.pending.len() >= *batch_size
|| self.last_flush.elapsed() >= Duration::from_millis(*batch_timeout_ms)
}
IngestionSafetyLevel::Async { sync_interval_ms } => {
self.last_flush.elapsed() >= Duration::from_millis(*sync_interval_ms)
}
IngestionSafetyLevel::Unsafe { .. } => {
self.pending_bytes >= 1024 * 1024 }
}
}
pub fn process_batch<F>(&mut self, flush_fn: F) -> Result<usize, String>
where
F: FnOnce(&[CommitRequest], bool) -> Result<u64, String>,
{
let timeout = match &self.safety_level {
IngestionSafetyLevel::Full => Duration::from_millis(0),
IngestionSafetyLevel::Batched { batch_timeout_ms, .. } => {
Duration::from_millis(*batch_timeout_ms)
}
IngestionSafetyLevel::Async { sync_interval_ms } => {
Duration::from_millis(*sync_interval_ms)
}
IngestionSafetyLevel::Unsafe { .. } => Duration::from_millis(100),
};
match self.receiver.recv_timeout(timeout) {
Ok(request) => {
self.pending_bytes += request.operations.iter().map(|op| op.size()).sum::<usize>();
self.pending.push(request);
}
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {}
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
return Err("Coordinator disconnected".to_string());
}
}
while let Ok(request) = self.receiver.try_recv() {
self.pending_bytes += request.operations.iter().map(|op| op.size()).sum::<usize>();
self.pending.push(request);
if self.pending.len() >= 10000 {
break;
}
}
if self.should_flush() {
let sync = self.safety_level.sync_on_commit();
let commit_ts = flush_fn(&self.pending, sync)?;
let count = self.pending.len();
for request in self.pending.drain(..) {
if let Some(response) = request.response {
let _ = response.send(CommitResult::Success {
commit_timestamp: commit_ts,
ops_count: request.operations.len(),
});
}
self.coordinator.record_commit(
request.operations.len(),
request.operations.iter().map(|op| op.size()).sum(),
);
}
self.pending_bytes = 0;
self.last_flush = Instant::now();
return Ok(count);
}
Ok(0)
}
}
/// Pool of reusable [`TransactionBuffer`]s.
///
/// Each acquired buffer is stamped with a fresh transaction id and the pool's current
/// read timestamp.
///
/// NOTE(review): the pool keeps its own id/timestamp counters, independent of any
/// `WriteCoordinator`. Confirm callers never mix ids from both sources.
pub struct WriteBufferPool {
    available: crossbeam::queue::ArrayQueue<TransactionBuffer>,
    // Number of buffers the pool was created with.
    pool_size: usize,
    next_txn_id: AtomicU64,
    current_timestamp: AtomicU64,
}

impl WriteBufferPool {
    /// Creates a pool pre-filled with `pool_size` empty buffers.
    ///
    /// A `pool_size` of zero is accepted. The pool then starts empty and can hold at
    /// most one released buffer.
    pub fn new(pool_size: usize) -> Self {
        // `ArrayQueue::new` panics on zero capacity, so always reserve at least one slot.
        let available = crossbeam::queue::ArrayQueue::new(pool_size.max(1));
        for _ in 0..pool_size {
            // Cannot fail: the queue has room for at least `pool_size` buffers.
            let _ = available.push(TransactionBuffer::new(0, 0));
        }
        Self {
            available,
            pool_size,
            next_txn_id: AtomicU64::new(1),
            current_timestamp: AtomicU64::new(1),
        }
    }

    /// Number of buffers the pool was created with.
    pub fn pool_size(&self) -> usize {
        self.pool_size
    }

    /// Returns a reset buffer (empty operations and read set, size 0, read-only) with a
    /// fresh transaction id. A new buffer is allocated when the pool is empty.
    pub fn acquire(&self) -> TransactionBuffer {
        let txn_id = self.next_txn_id.fetch_add(1, Ordering::Relaxed);
        let read_ts = self.current_timestamp.load(Ordering::Acquire);
        match self.available.pop() {
            Some(mut buffer) => {
                buffer.txn_id = txn_id;
                buffer.read_timestamp = read_ts;
                buffer.operations.clear();
                buffer.size = 0;
                buffer.read_set.clear();
                buffer.read_only = true;
                buffer
            }
            None => TransactionBuffer::new(txn_id, read_ts),
        }
    }

    /// Returns a buffer to the pool. If the pool is already full, the buffer is dropped.
    pub fn release(&self, mut buffer: TransactionBuffer) {
        // Drop the previous transaction's payloads now instead of keeping them alive
        // until the buffer is acquired again; the vectors keep their capacity.
        buffer.operations.clear();
        buffer.read_set.clear();
        buffer.size = 0;
        buffer.read_only = true;
        let _ = self.available.push(buffer);
    }

    /// Increments the pool's timestamp and returns the new value.
    pub fn advance_timestamp(&self) -> u64 {
        self.current_timestamp.fetch_add(1, Ordering::AcqRel) + 1
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transaction_buffer() {
        let mut buffer = TransactionBuffer::new(1, 100);
        assert!(buffer.is_read_only());
        assert!(buffer.is_empty());
        assert_eq!(buffer.size(), 0);

        buffer.insert("test".to_string(), 1, vec![1, 2, 3]);
        assert!(!buffer.is_read_only());
        assert_eq!(buffer.len(), 1);
        // "test" (4) + row id (8) + payload (3)
        assert_eq!(buffer.size(), 15);

        buffer.update("test".to_string(), 1, vec![4, 5, 6]);
        assert_eq!(buffer.len(), 2);
        buffer.delete("test".to_string(), 2);
        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.operation_count(), 3);
        assert_eq!(buffer.size(), 15 + 15 + 12);

        let ops = buffer.take_operations();
        assert!(matches!(ops[2], WriteOp::Delete { row_id: 2, .. }));
        assert_eq!(ops[0].key(), b"data:test:1".to_vec());
    }

    #[test]
    fn test_record_read_keeps_buffer_read_only() {
        let mut buffer = TransactionBuffer::new(1, 100);
        buffer.record_read("t", 1);
        buffer.record_read("t", 2);
        assert!(buffer.is_read_only());
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_write_coordinator() {
        let coord = WriteCoordinator::new(IngestionSafetyLevel::Full, 1000);
        let buffer1 = coord.begin_transaction();
        let buffer2 = coord.begin_transaction();
        assert_ne!(buffer1.txn_id, buffer2.txn_id);
        assert!(buffer2.txn_id > buffer1.txn_id);
        assert_eq!(coord.current_timestamp(), 1);
        assert_eq!(coord.advance_timestamp(), 2);
        assert_eq!(coord.begin_transaction().read_timestamp, 2);
    }

    #[test]
    fn test_submit_commit_backpressure() {
        let coord = WriteCoordinator::new(IngestionSafetyLevel::Full, 1);
        // Read-only transactions never touch the queue.
        assert!(coord.submit_commit(coord.begin_transaction()).is_ok());

        let mut first = coord.begin_transaction();
        first.insert("t".to_string(), 1, vec![0]);
        assert!(coord.submit_commit(first).is_ok());

        let mut second = coord.begin_transaction();
        second.insert("t".to_string(), 2, vec![0]);
        assert_eq!(
            coord.submit_commit(second),
            Err("Commit queue full - backpressure".to_string())
        );

        let receiver = coord.take_receiver();
        assert!(receiver.is_some());
        assert!(coord.take_receiver().is_none());
        if let Some(receiver) = receiver {
            let queued = receiver.try_recv().map(|request| request.operations.len());
            assert_eq!(queued.ok(), Some(1));
            assert!(receiver.try_recv().is_err());
        }
    }

    #[test]
    fn test_batched_worker_flush_and_retry() {
        let coord = Arc::new(WriteCoordinator::new(IngestionSafetyLevel::Full, 8));
        let mut txn = coord.begin_transaction();
        txn.insert("t".to_string(), 1, vec![1, 2, 3]);
        assert!(coord.submit_commit(txn).is_ok());

        let mut worker = BatchedCommitWorker::new(Arc::clone(&coord));

        // A failed flush reports the error and keeps the batch for a later call.
        let failed = worker.process_batch(|_requests, _sync| Err("disk".to_string()));
        assert_eq!(failed, Err("disk".to_string()));
        assert_eq!(coord.stats(), (0, 0, 0));

        let flushed = worker.process_batch(|requests, _sync| {
            assert_eq!(requests.len(), 1);
            assert_eq!(requests[0].operations.len(), 1);
            Ok(7)
        });
        assert_eq!(flushed, Ok(1));
        // "t" (1) + row id (8) + payload (3)
        assert_eq!(coord.stats(), (1, 1, 12));
    }

    #[test]
    fn test_buffer_pool() {
        let pool = WriteBufferPool::new(10);
        let b1 = pool.acquire();
        let b2 = pool.acquire();
        assert_ne!(b1.txn_id, b2.txn_id);
        pool.release(b1);
        let b3 = pool.acquire();
        assert!(b3.is_empty());
    }

    #[test]
    fn test_buffer_pool_resets_reused_buffer() {
        // With a single slot, the released buffer is the next one handed out.
        let pool = WriteBufferPool::new(1);
        let mut first = pool.acquire();
        first.insert("t".to_string(), 1, vec![1, 2, 3]);
        first.record_read("t", 1);
        let first_id = first.txn_id;
        pool.release(first);

        let reused = pool.acquire();
        assert!(reused.is_empty());
        assert!(reused.is_read_only());
        assert_eq!(reused.size(), 0);
        assert!(reused.txn_id > first_id);
        assert_eq!(reused.read_timestamp, 1);

        assert_eq!(pool.advance_timestamp(), 2);
        assert_eq!(pool.acquire().read_timestamp, 2);
    }
}