use std::sync::atomic::{AtomicU64, AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use tokio::sync::oneshot;
use serde::{Deserialize, Serialize};
/// Key under which pending inserts are grouped: the table name given to the batcher.
pub type TableId = String;
/// Identifier handed out by the batcher for each queued insert.
///
/// It is a plain `Copy` newtype over `u64`, so it can be compared, hashed and
/// passed around freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BatchTicketId(u64);
/// Tuning knobs for [`InsertBatcher`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchConfig {
    /// Master switch; when `false`, `InsertBatcher::add` rejects every insert.
    pub enabled: bool,
    /// Flush a table's batch once it holds at least this many value rows.
    pub max_batch_size: usize,
    /// Flush once the oldest queued request has waited this many milliseconds.
    pub max_wait_ms: u64,
    /// Flush once the summed length of the queued requests' original SQL
    /// reaches this many bytes.
    pub max_batch_bytes: usize,
    /// NOTE(review): not read anywhere in this file; presumably meant to gate
    /// `InsertBatcher::start_auto_flush` — confirm.
    pub auto_flush: bool,
    /// Allow-list of tables to batch; an empty list means every table is batched.
    pub batch_tables: Vec<String>,
}

impl Default for BatchConfig {
    /// Batching on, 1000 rows / 10 ms / 16 MiB limits, auto flush on, all tables.
    fn default() -> Self {
        const SIXTEEN_MIB: usize = 16 * 1024 * 1024;
        BatchConfig {
            enabled: true,
            max_batch_size: 1_000,
            max_wait_ms: 10,
            max_batch_bytes: SIXTEEN_MIB,
            auto_flush: true,
            batch_tables: Vec::new(),
        }
    }
}
/// One queued INSERT awaiting a batch flush.
#[derive(Debug)]
pub struct InsertRequest {
    /// Target table name.
    pub table: String,
    /// Column names, in the order used by each row of `values`.
    pub columns: Vec<String>,
    /// Value rows; each inner vector holds already-rendered SQL literals.
    pub values: Vec<Vec<String>>,
    /// The statement as originally submitted; its length drives the byte limit.
    pub original_sql: String,
    /// When the request was queued.
    pub submitted_at: Instant,
    /// One-shot channel used to deliver the result; `None` once it has been used.
    response_tx: Option<oneshot::Sender<BatchResult>>,
}
/// Outcome delivered to a [`BatchTicket`] when its batch is flushed.
#[derive(Debug, Clone)]
pub struct BatchResult {
    /// Ticket id reported by the batcher for this insert.
    pub ticket_id: BatchTicketId,
    /// Number of value rows this particular request contributed.
    pub rows_inserted: u64,
    /// Whether the batch was reported as executed successfully.
    pub success: bool,
    /// Failure description when `success` is `false`.
    pub error: Option<String>,
    /// Time the request spent queued before its batch executed.
    pub wait_time: Duration,
    /// Time spent executing the whole batch.
    pub execution_time: Duration,
}
/// Handle returned by `InsertBatcher::add`; resolves once the insert's batch is flushed.
pub struct BatchTicket {
    id: BatchTicketId,
    rx: oneshot::Receiver<BatchResult>,
}

impl BatchTicket {
    /// Waits, without a deadline, for the batch containing this insert to be flushed.
    ///
    /// # Errors
    /// [`BatchError::ChannelClosed`] if the sending side was dropped without a result.
    pub async fn wait(self) -> Result<BatchResult, BatchError> {
        match self.rx.await {
            Ok(result) => Ok(result),
            Err(_) => Err(BatchError::ChannelClosed),
        }
    }

    /// Like [`BatchTicket::wait`], but gives up after `timeout`.
    ///
    /// # Errors
    /// - [`BatchError::Timeout`] if no result arrived within `timeout`.
    /// - [`BatchError::ChannelClosed`] if the sending side was dropped first.
    pub async fn wait_timeout(self, timeout: Duration) -> Result<BatchResult, BatchError> {
        let outcome = tokio::time::timeout(timeout, self.rx).await;
        match outcome {
            Err(_elapsed) => Err(BatchError::Timeout),
            Ok(Err(_closed)) => Err(BatchError::ChannelClosed),
            Ok(Ok(result)) => Ok(result),
        }
    }

    /// The id assigned to this ticket when the insert was queued.
    pub fn id(&self) -> BatchTicketId {
        self.id
    }
}
/// Errors surfaced by the insert batcher and its tickets.
#[derive(Debug, Clone)]
pub enum BatchError {
    /// Batching is switched off, or the table is not on the allow-list.
    Disabled,
    /// The batch cannot accept more work.
    BatchFull,
    /// A bounded wait for a result expired.
    Timeout,
    /// The result channel was dropped before a result was sent.
    ChannelClosed,
    /// Execution failed (or the batcher refused work); carries a description.
    ExecutionFailed(String),
}

impl std::fmt::Display for BatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::ExecutionFailed(detail) => return write!(f, "Execution failed: {}", detail),
            Self::Disabled => "Batching is disabled",
            Self::BatchFull => "Batch is full",
            Self::Timeout => "Batch timeout",
            Self::ChannelClosed => "Channel closed",
        };
        f.write_str(message)
    }
}

/// Lets `BatchError` flow through `Box<dyn Error>` and `?` conversions.
impl std::error::Error for BatchError {}
/// Counters and smoothed averages describing batcher activity.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BatchStats {
    /// Number of `add` calls that were accepted.
    pub inserts_received: u64,
    /// Total value rows across accepted inserts.
    pub rows_received: u64,
    /// Number of non-empty batches flushed.
    pub batches_flushed: u64,
    /// Total value rows across flushed batches.
    pub rows_inserted: u64,
    /// Smoothed average of rows per flushed batch.
    pub avg_batch_size: f64,
    /// Average time, in milliseconds, a request waited before its batch executed.
    pub avg_wait_time_ms: f64,
    /// Smoothed average batch execution time, in milliseconds.
    pub avg_execution_time_ms: f64,
    /// Flushes caused by hitting a row or byte limit.
    pub size_triggered_flushes: u64,
    /// Flushes caused by the oldest request exceeding the wait limit.
    pub time_triggered_flushes: u64,
}
struct PendingBatch {
requests: Vec<InsertRequest>,
row_count: usize,
byte_count: usize,
first_submitted: Instant,
}
impl PendingBatch {
fn new() -> Self {
Self {
requests: Vec::with_capacity(100),
row_count: 0,
byte_count: 0,
first_submitted: Instant::now(),
}
}
fn add(&mut self, request: InsertRequest) {
let row_count = request.values.len();
let byte_estimate = request.original_sql.len();
if self.requests.is_empty() {
self.first_submitted = request.submitted_at;
}
self.row_count += row_count;
self.byte_count += byte_estimate;
self.requests.push(request);
}
fn is_empty(&self) -> bool {
self.requests.is_empty()
}
fn should_flush(&self, config: &BatchConfig) -> bool {
self.row_count >= config.max_batch_size ||
self.byte_count >= config.max_batch_bytes ||
self.first_submitted.elapsed().as_millis() as u64 >= config.max_wait_ms
}
fn drain(&mut self) -> (Vec<InsertRequest>, usize) {
let row_count = self.row_count;
self.row_count = 0;
self.byte_count = 0;
(std::mem::take(&mut self.requests), row_count)
}
}
pub struct InsertBatcher {
config: BatchConfig,
pending: DashMap<TableId, PendingBatch>,
next_ticket_id: AtomicU64,
stats: Arc<parking_lot::RwLock<BatchStats>>,
shutdown: AtomicBool,
}
impl InsertBatcher {
pub fn new(config: BatchConfig) -> Self {
Self {
config,
pending: DashMap::new(),
next_ticket_id: AtomicU64::new(1),
stats: Arc::new(parking_lot::RwLock::new(BatchStats::default())),
shutdown: AtomicBool::new(false),
}
}
pub fn add(
&self,
table: String,
columns: Vec<String>,
values: Vec<Vec<String>>,
original_sql: String,
) -> Result<BatchTicket, BatchError> {
if !self.config.enabled {
return Err(BatchError::Disabled);
}
if self.shutdown.load(Ordering::Relaxed) {
return Err(BatchError::ExecutionFailed("Batcher shutdown".to_string()));
}
if !self.config.batch_tables.is_empty() &&
!self.config.batch_tables.contains(&table)
{
return Err(BatchError::Disabled);
}
let ticket_id = BatchTicketId(self.next_ticket_id.fetch_add(1, Ordering::Relaxed));
let (tx, rx) = oneshot::channel();
let row_count = values.len();
let request = InsertRequest {
table: table.clone(),
columns,
values,
original_sql,
submitted_at: Instant::now(),
response_tx: Some(tx),
};
{
let mut stats = self.stats.write();
stats.inserts_received += 1;
stats.rows_received += row_count as u64;
}
let should_flush = {
let mut batch = self.pending.entry(table.clone()).or_insert_with(PendingBatch::new);
batch.add(request);
batch.should_flush(&self.config)
};
if should_flush {
self.flush_batch(&table);
}
Ok(BatchTicket { id: ticket_id, rx })
}
pub fn flush_batch(&self, table: &str) {
if let Some((_, mut batch)) = self.pending.remove(table) {
if batch.is_empty() {
return;
}
let (requests, row_count) = batch.drain();
let execution_start = Instant::now();
let _combined_sql = self.combine_inserts(&requests);
let success = true; let error: Option<String> = None;
let execution_time = execution_start.elapsed();
{
let mut stats = self.stats.write();
stats.batches_flushed += 1;
stats.rows_inserted += row_count as u64;
if stats.batches_flushed == 1 {
stats.avg_batch_size = row_count as f64;
} else {
stats.avg_batch_size = stats.avg_batch_size * 0.9 + row_count as f64 * 0.1;
}
let exec_ms = execution_time.as_millis() as f64;
if stats.batches_flushed == 1 {
stats.avg_execution_time_ms = exec_ms;
} else {
stats.avg_execution_time_ms = stats.avg_execution_time_ms * 0.9 + exec_ms * 0.1;
}
}
for mut req in requests {
let wait_time = req.submitted_at.elapsed() - execution_time;
if let Some(tx) = req.response_tx.take() {
let _ = tx.send(BatchResult {
ticket_id: BatchTicketId(0), rows_inserted: req.values.len() as u64,
success,
error: error.clone(),
wait_time,
execution_time,
});
}
}
}
}
fn combine_inserts(&self, requests: &[InsertRequest]) -> String {
if requests.is_empty() {
return String::new();
}
let first = &requests[0];
let table = &first.table;
let columns = &first.columns;
let mut sql = format!(
"INSERT INTO {} ({}) VALUES ",
table,
columns.join(", ")
);
let mut value_parts: Vec<String> = Vec::new();
for req in requests {
for row in &req.values {
value_parts.push(format!("({})", row.join(", ")));
}
}
sql.push_str(&value_parts.join(", "));
sql
}
pub fn flush_all(&self) {
let tables: Vec<TableId> = self.pending.iter().map(|r| r.key().clone()).collect();
for table in tables {
self.flush_batch(&table);
}
}
pub fn batch_size(&self, table: &str) -> usize {
self.pending
.get(table)
.map(|b| b.row_count)
.unwrap_or(0)
}
pub fn stats(&self) -> BatchStats {
self.stats.read().clone()
}
pub fn shutdown(&self) {
self.shutdown.store(true, Ordering::Release);
self.flush_all();
}
pub fn start_auto_flush(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
let interval = Duration::from_millis(self.config.max_wait_ms);
tokio::spawn(async move {
let mut interval_timer = tokio::time::interval(interval);
loop {
interval_timer.tick().await;
if self.shutdown.load(Ordering::Relaxed) {
break;
}
let tables: Vec<TableId> = self.pending
.iter()
.filter(|r| {
r.first_submitted.elapsed().as_millis() as u64 >= self.config.max_wait_ms
})
.map(|r| r.key().clone())
.collect();
for table in tables {
self.flush_batch(&table);
self.stats.write().time_triggered_flushes += 1;
}
}
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string slices into the owned `Vec<String>` shape the batcher API takes.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[tokio::test]
    async fn test_batch_add() {
        let batcher = InsertBatcher::new(BatchConfig::default());
        let _ticket = batcher
            .add(
                "users".to_string(),
                owned(&["id", "name"]),
                vec![owned(&["1", "'Alice'"])],
                "INSERT INTO users (id, name) VALUES (1, 'Alice')".to_string(),
            )
            .unwrap();
        // The single row is still sitting in the pending batch.
        assert_eq!(batcher.batch_size("users"), 1);
    }

    #[tokio::test]
    async fn test_batch_flush_on_size() {
        let batcher = InsertBatcher::new(BatchConfig {
            max_batch_size: 2,
            ..Default::default()
        });

        batcher
            .add(
                "users".to_string(),
                owned(&["id"]),
                vec![owned(&["1"])],
                "INSERT INTO users VALUES (1)".to_string(),
            )
            .unwrap();
        assert_eq!(batcher.batch_size("users"), 1);

        // The second row hits `max_batch_size`, so the batch is flushed inline.
        batcher
            .add(
                "users".to_string(),
                owned(&["id"]),
                vec![owned(&["2"])],
                "INSERT INTO users VALUES (2)".to_string(),
            )
            .unwrap();
        assert_eq!(batcher.batch_size("users"), 0);
    }

    #[test]
    fn test_combine_inserts() {
        let batcher = InsertBatcher::new(BatchConfig::default());
        let request = |id: &str, name: &str| InsertRequest {
            table: "users".to_string(),
            columns: owned(&["id", "name"]),
            values: vec![owned(&[id, name])],
            original_sql: String::new(),
            submitted_at: Instant::now(),
            response_tx: None,
        };
        let requests = vec![request("1", "'Alice'"), request("2", "'Bob'")];

        let combined = batcher.combine_inserts(&requests);
        assert!(combined.contains("INSERT INTO users"));
        assert!(combined.contains("(1, 'Alice')"));
        assert!(combined.contains("(2, 'Bob')"));
    }

    #[test]
    fn test_batch_stats() {
        let batcher = InsertBatcher::new(BatchConfig::default());
        batcher
            .add(
                "users".to_string(),
                owned(&["id"]),
                vec![owned(&["1"]), owned(&["2"])],
                "INSERT INTO users VALUES (1), (2)".to_string(),
            )
            .unwrap();

        let snapshot = batcher.stats();
        assert_eq!(snapshot.inserts_received, 1);
        assert_eq!(snapshot.rows_received, 2);
    }
}