use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::oneshot;
/// Name of a database table; used as the key for the per-table pending batches.
pub type TableId = String;
/// Opaque numeric identifier handed out for every insert submitted to the batcher.
///
/// The wrapped counter value is private; the type only supports copying,
/// comparison, hashing and debug formatting.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct BatchTicketId(u64);
/// Tunable settings for the insert batcher.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchConfig {
    /// Master switch; while `false`, submitting an insert is rejected with
    /// `BatchError::Disabled`.
    pub enabled: bool,
    /// Number of buffered rows at which a pending batch becomes due for a flush.
    pub max_batch_size: usize,
    /// Age, in milliseconds, of the oldest buffered request at which a batch
    /// becomes due for a flush. Also used as the period of the auto-flush timer.
    pub max_wait_ms: u64,
    /// Byte budget at which a pending batch becomes due for a flush. The size is
    /// estimated from the length of each request's original SQL text.
    pub max_batch_bytes: usize,
    /// NOTE(review): not read by any code in this file; presumably the caller
    /// uses it to decide whether to start the auto-flush task. Confirm.
    pub auto_flush: bool,
    /// Allow-list of tables eligible for batching. An empty list means every
    /// table is eligible; otherwise inserts for unlisted tables are rejected
    /// with `BatchError::Disabled`.
    pub batch_tables: Vec<String>,
}
impl Default for BatchConfig {
fn default() -> Self {
Self {
enabled: true,
max_batch_size: 1000,
max_wait_ms: 10,
max_batch_bytes: 16 * 1024 * 1024, auto_flush: true,
batch_tables: Vec::new(), }
}
}
/// One buffered insert waiting to be merged into a batch.
#[derive(Debug)]
pub struct InsertRequest {
    /// Target table name.
    pub table: String,
    /// Column names for the insert.
    pub columns: Vec<String>,
    /// Rows to insert. Each inner vector is one row whose entries are spliced
    /// verbatim into the combined statement, so they must already be valid SQL
    /// value text (for example `'Alice'` including the quotes).
    pub values: Vec<Vec<String>>,
    /// SQL text the request was built from; its length is the batch's byte estimate.
    pub original_sql: String,
    /// Moment the request was submitted.
    pub submitted_at: Instant,
    /// Channel used to deliver the outcome to the submitter; `None` when nobody
    /// is waiting for it.
    response_tx: Option<oneshot::Sender<BatchResult>>,
}
/// Outcome reported to the submitter of a single insert once its batch ran.
#[derive(Debug, Clone)]
pub struct BatchResult {
    /// Identifier of the ticket this result belongs to.
    pub ticket_id: BatchTicketId,
    /// Rows contributed by this request that were inserted (0 on failure).
    pub rows_inserted: u64,
    /// Whether the batch this request was part of executed successfully.
    pub success: bool,
    /// Failure description when `success` is `false`.
    pub error: Option<String>,
    /// Time the request spent queued, excluding the batch execution itself.
    pub wait_time: Duration,
    /// Time the batch took to execute against the backend.
    pub execution_time: Duration,
}
/// Handle returned for a submitted insert; await it to learn the batch outcome.
pub struct BatchTicket {
    /// Identifier assigned to this submission.
    id: BatchTicketId,
    /// Receiving half of the channel on which the result is delivered.
    rx: oneshot::Receiver<BatchResult>,
}
impl BatchTicket {
    /// Waits until the batch containing this insert has been executed.
    ///
    /// The returned result always carries this ticket's own id. The sending
    /// side does not track per-request ticket ids (it sends a placeholder), so
    /// the id is stamped here, where it is known.
    ///
    /// # Errors
    ///
    /// Returns [`BatchError::ChannelClosed`] if the sender was dropped without
    /// delivering a result.
    pub async fn wait(self) -> Result<BatchResult, BatchError> {
        let Self { id, rx } = self;
        let mut result = rx.await.map_err(|_| BatchError::ChannelClosed)?;
        result.ticket_id = id;
        Ok(result)
    }

    /// Like [`BatchTicket::wait`], but gives up after `timeout`.
    ///
    /// # Errors
    ///
    /// Returns [`BatchError::Timeout`] if no result arrived within `timeout`,
    /// or [`BatchError::ChannelClosed`] if the sender was dropped without
    /// delivering a result.
    pub async fn wait_timeout(self, timeout: Duration) -> Result<BatchResult, BatchError> {
        let Self { id, rx } = self;
        let mut result = tokio::time::timeout(timeout, rx)
            .await
            .map_err(|_| BatchError::Timeout)?
            .map_err(|_| BatchError::ChannelClosed)?;
        result.ticket_id = id;
        Ok(result)
    }

    /// Identifier assigned to this submission.
    pub fn id(&self) -> BatchTicketId {
        self.id
    }
}
/// Errors surfaced by the insert batcher and its tickets.
#[derive(Debug, Clone)]
pub enum BatchError {
    /// Batching is switched off, or the target table is not eligible.
    Disabled,
    /// The batch cannot accept more requests.
    BatchFull,
    /// Waiting for a result exceeded the caller's timeout.
    Timeout,
    /// The result channel was closed before a result was delivered.
    ChannelClosed,
    /// The batch could not be executed; the payload describes why.
    ExecutionFailed(String),
}

impl std::fmt::Display for BatchError {
    /// Writes the human-readable message for each variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only one variant carries data; every other message is a fixed string.
        let fixed_message = match self {
            Self::ExecutionFailed(reason) => return write!(f, "Execution failed: {}", reason),
            Self::Disabled => "Batching is disabled",
            Self::BatchFull => "Batch is full",
            Self::Timeout => "Batch timeout",
            Self::ChannelClosed => "Channel closed",
        };
        f.write_str(fixed_message)
    }
}

impl std::error::Error for BatchError {}
/// Counters and running averages describing batcher activity.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BatchStats {
    /// Insert requests accepted.
    pub inserts_received: u64,
    /// Rows contained in the accepted requests.
    pub rows_received: u64,
    /// Batches executed, successful or not.
    pub batches_flushed: u64,
    /// Rows in batches that executed successfully.
    pub rows_inserted: u64,
    /// Running average of rows per flushed batch.
    pub avg_batch_size: f64,
    /// Average time requests waited before execution, in milliseconds.
    pub avg_wait_time_ms: f64,
    /// Running average of batch execution time, in milliseconds.
    pub avg_execution_time_ms: f64,
    /// Flushes triggered by a size limit.
    pub size_triggered_flushes: u64,
    /// Flushes triggered by the wait-time limit.
    pub time_triggered_flushes: u64,
    /// Batches whose execution failed.
    pub flush_failures: u64,
}
/// Requests buffered for one table, plus the totals used to decide when to flush.
struct PendingBatch {
    /// Buffered requests in submission order.
    requests: Vec<InsertRequest>,
    /// Total rows across `requests`.
    row_count: usize,
    /// Estimated size in bytes across `requests`.
    byte_count: usize,
    /// Submission time of the oldest buffered request.
    first_submitted: Instant,
}
impl PendingBatch {
fn new() -> Self {
Self {
requests: Vec::with_capacity(100),
row_count: 0,
byte_count: 0,
first_submitted: Instant::now(),
}
}
fn add(&mut self, request: InsertRequest) {
let row_count = request.values.len();
let byte_estimate = request.original_sql.len();
if self.requests.is_empty() {
self.first_submitted = request.submitted_at;
}
self.row_count += row_count;
self.byte_count += byte_estimate;
self.requests.push(request);
}
fn is_empty(&self) -> bool {
self.requests.is_empty()
}
fn should_flush(&self, config: &BatchConfig) -> bool {
self.row_count >= config.max_batch_size
|| self.byte_count >= config.max_batch_bytes
|| self.first_submitted.elapsed().as_millis() as u64 >= config.max_wait_ms
}
fn drain(&mut self) -> (Vec<InsertRequest>, usize) {
let row_count = self.row_count;
self.row_count = 0;
self.byte_count = 0;
(std::mem::take(&mut self.requests), row_count)
}
}
/// Buffers single-table inserts and executes them as combined statements.
pub struct InsertBatcher {
    /// Flush thresholds and eligibility rules.
    config: BatchConfig,
    /// Pending batch per table.
    pending: DashMap<TableId, PendingBatch>,
    /// Source of ticket ids; starts at 1.
    next_ticket_id: AtomicU64,
    /// Activity counters, shared behind a read/write lock.
    stats: Arc<parking_lot::RwLock<BatchStats>>,
    /// Set once shutdown has begun; new inserts are rejected afterwards.
    shutdown: AtomicBool,
    /// Backend connection settings; while `None`, every flush fails with
    /// "no backend configured".
    backend: Option<crate::backend::BackendConfig>,
}
/// A batch detached from the pending map and ready to execute.
struct SealedBatch {
    /// Requests whose submitters must be notified after execution.
    requests: Vec<InsertRequest>,
    /// Total rows across `requests`.
    row_count: usize,
    /// Combined INSERT statement covering all requests.
    sql: String,
}
impl InsertBatcher {
pub fn new(config: BatchConfig) -> Self {
Self {
config,
pending: DashMap::new(),
next_ticket_id: AtomicU64::new(1),
stats: Arc::new(parking_lot::RwLock::new(BatchStats::default())),
shutdown: AtomicBool::new(false),
backend: None,
}
}
pub fn with_backend(mut self, backend: crate::backend::BackendConfig) -> Self {
self.backend = Some(backend);
self
}
/// Queues an insert for `table` and returns a ticket that resolves once the
/// batch containing it has been executed.
///
/// If adding the request makes the batch due (row limit, byte limit or wait
/// limit reached), the batch is sealed and executed on a spawned Tokio task.
/// When no Tokio runtime is active on the calling thread the batch is left
/// pending, to be picked up by a later flush, instead of panicking.
///
/// # Errors
///
/// * [`BatchError::Disabled`] if batching is disabled, or `batch_tables` is
///   non-empty and does not contain `table`.
/// * [`BatchError::ExecutionFailed`] if the batcher has been shut down.
pub fn add(
    self: &Arc<Self>,
    table: String,
    columns: Vec<String>,
    values: Vec<Vec<String>>,
    original_sql: String,
) -> Result<BatchTicket, BatchError> {
    if !self.config.enabled {
        return Err(BatchError::Disabled);
    }
    // Acquire pairs with the Release store performed by `shutdown`.
    if self.shutdown.load(Ordering::Acquire) {
        return Err(BatchError::ExecutionFailed("Batcher shutdown".to_string()));
    }
    if !self.config.batch_tables.is_empty() && !self.config.batch_tables.contains(&table) {
        return Err(BatchError::Disabled);
    }

    let ticket_id = BatchTicketId(self.next_ticket_id.fetch_add(1, Ordering::Relaxed));
    let (tx, rx) = oneshot::channel();
    let row_count = values.len();
    let request = InsertRequest {
        table: table.clone(),
        columns,
        values,
        original_sql,
        submitted_at: Instant::now(),
        response_tx: Some(tx),
    };

    {
        let mut stats = self.stats.write();
        stats.inserts_received += 1;
        stats.rows_received += row_count as u64;
    }

    // Decide while the map entry is still held, then release it before
    // sealing: `seal` removes the entry and must not run under the entry guard.
    let (should_flush, size_triggered) = {
        let mut batch = self
            .pending
            .entry(table.clone())
            .or_insert_with(PendingBatch::new);
        batch.add(request);
        let size_triggered = batch.row_count >= self.config.max_batch_size
            || batch.byte_count >= self.config.max_batch_bytes;
        (batch.should_flush(&self.config), size_triggered)
    };

    if should_flush {
        // Spawning needs a runtime; without one the batch simply stays pending.
        if let Ok(runtime) = tokio::runtime::Handle::try_current() {
            // Another caller may have sealed the batch in the meantime.
            if let Some(sealed) = self.seal(&table) {
                {
                    let mut stats = self.stats.write();
                    if size_triggered {
                        stats.size_triggered_flushes += 1;
                    } else {
                        stats.time_triggered_flushes += 1;
                    }
                }
                let me = Arc::clone(self);
                runtime.spawn(async move {
                    me.execute_sealed(sealed).await;
                });
            }
        }
    }

    Ok(BatchTicket { id: ticket_id, rx })
}
/// Detaches the pending batch for `table` and turns it into a [`SealedBatch`].
///
/// Returns `None` when the table has no pending batch or the batch holds no
/// requests; in both cases nothing remains in the map for that table.
fn seal(&self, table: &str) -> Option<SealedBatch> {
    let mut detached = self
        .pending
        .remove(table)
        .map(|(_key, batch)| batch)
        .filter(|batch| !batch.is_empty())?;

    let (requests, row_count) = detached.drain();
    // The combined statement is built once, at seal time.
    let sql = self.combine_inserts(&requests);

    Some(SealedBatch {
        requests,
        row_count,
        sql,
    })
}
async fn execute_sealed(&self, sealed: SealedBatch) {
let SealedBatch {
requests,
row_count,
sql,
} = sealed;
let execution_start = Instant::now();
let (success, error) = match &self.backend {
Some(cfg) => match crate::backend::BackendClient::connect(cfg).await {
Ok(mut client) => {
let outcome = client.execute(&sql).await;
client.close().await;
match outcome {
Ok(_tag) => (true, None),
Err(e) => (false, Some(format!("execute: {}", e))),
}
}
Err(e) => (false, Some(format!("connect: {}", e))),
},
None => (false, Some("no backend configured".to_string())),
};
let execution_time = execution_start.elapsed();
{
let mut stats = self.stats.write();
stats.batches_flushed += 1;
if success {
stats.rows_inserted += row_count as u64;
} else {
stats.flush_failures += 1;
}
if stats.batches_flushed == 1 {
stats.avg_batch_size = row_count as f64;
} else {
stats.avg_batch_size = stats.avg_batch_size * 0.9 + row_count as f64 * 0.1;
}
let exec_ms = execution_time.as_millis() as f64;
if stats.batches_flushed == 1 {
stats.avg_execution_time_ms = exec_ms;
} else {
stats.avg_execution_time_ms = stats.avg_execution_time_ms * 0.9 + exec_ms * 0.1;
}
}
for mut req in requests {
let wait_time = req
.submitted_at
.elapsed()
.checked_sub(execution_time)
.unwrap_or_default();
if let Some(tx) = req.response_tx.take() {
let _ = tx.send(BatchResult {
ticket_id: BatchTicketId(0), rows_inserted: if success { req.values.len() as u64 } else { 0 },
success,
error: error.clone(),
wait_time,
execution_time,
});
}
}
}
/// Seals and executes whatever is pending for `table`; does nothing when the
/// table has no buffered requests.
pub async fn flush_batch(&self, table: &str) {
    match self.seal(table) {
        Some(sealed) => self.execute_sealed(sealed).await,
        None => {}
    }
}
/// Builds one multi-row `INSERT` statement covering every row of `requests`.
///
/// The table and column list are taken from the first request. When that
/// column list is empty it is omitted entirely (`INSERT INTO t VALUES ...`)
/// so that the invalid `INSERT INTO t () VALUES ...` is never produced.
/// Returns an empty string when `requests` is empty.
///
/// NOTE(review): requests with differing column lists are all emitted under
/// the first request's columns; callers presumably only batch inserts with
/// identical columns. Confirm.
///
/// NOTE(security): table names, column names and values are spliced into the
/// statement verbatim. They must already be trusted, correctly quoted SQL
/// text; nothing is escaped here.
fn combine_inserts(&self, requests: &[InsertRequest]) -> String {
    let first = match requests.first() {
        Some(first) => first,
        None => return String::new(),
    };

    let mut sql = if first.columns.is_empty() {
        format!("INSERT INTO {} VALUES ", first.table)
    } else {
        format!(
            "INSERT INTO {} ({}) VALUES ",
            first.table,
            first.columns.join(", ")
        )
    };

    // Append each row directly instead of collecting per-row strings first.
    let mut needs_separator = false;
    for row in requests.iter().flat_map(|req| req.values.iter()) {
        if needs_separator {
            sql.push_str(", ");
        }
        needs_separator = true;
        sql.push('(');
        sql.push_str(&row.join(", "));
        sql.push(')');
    }
    sql
}
/// Flushes every table that has a pending batch at the time of the call,
/// one table after another.
pub async fn flush_all(&self) {
    // Copy the keys out first so no map entry is held while flushing.
    let mut tables: Vec<TableId> = Vec::new();
    for entry in self.pending.iter() {
        tables.push(entry.key().clone());
    }

    for table in &tables {
        self.flush_batch(table).await;
    }
}
/// Number of rows currently buffered for `table`; 0 when nothing is pending.
pub fn batch_size(&self, table: &str) -> usize {
    self.pending
        .get(table)
        .map_or(0, |batch| batch.row_count)
}
/// Returns a copy of the current statistics.
pub fn stats(&self) -> BatchStats {
    let snapshot = self.stats.read();
    (*snapshot).clone()
}
/// Marks the batcher as shut down and then flushes every pending batch.
///
/// The flag is set before flushing so that `add` starts rejecting new inserts
/// while the remaining batches are executed.
pub async fn shutdown(&self) {
self.shutdown.store(true, Ordering::Release);
self.flush_all().await;
}
/// Spawns a background task that periodically flushes batches whose oldest
/// request has waited at least `max_wait_ms`.
///
/// The task ticks every `max_wait_ms` milliseconds (at least 1 ms, because a
/// zero-length interval is rejected by Tokio with a panic) and exits on the
/// first tick after shutdown has been requested. `time_triggered_flushes` is
/// incremented only for batches that were actually executed.
///
/// Must be called from within a Tokio runtime. Returns the handle of the
/// spawned task.
pub fn start_auto_flush(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
    let max_wait_ms = self.config.max_wait_ms;
    let period = Duration::from_millis(max_wait_ms.max(1));

    tokio::spawn(async move {
        let mut interval_timer = tokio::time::interval(period);
        loop {
            interval_timer.tick().await;
            // Acquire pairs with the Release store performed by `shutdown`.
            if self.shutdown.load(Ordering::Acquire) {
                break;
            }

            // Collect the due tables first so no map entry is held across an await.
            let tables: Vec<TableId> = self
                .pending
                .iter()
                .filter(|r| r.first_submitted.elapsed().as_millis() as u64 >= max_wait_ms)
                .map(|r| r.key().clone())
                .collect();

            for table in tables {
                // The batch may have been flushed by someone else since the scan.
                if let Some(sealed) = self.seal(&table) {
                    self.execute_sealed(sealed).await;
                    self.stats.write().time_triggered_flushes += 1;
                }
            }
        }
    })
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_batch_add() {
let batcher = Arc::new(InsertBatcher::new(BatchConfig::default()));
let _ticket = batcher
.add(
"users".to_string(),
vec!["id".to_string(), "name".to_string()],
vec![vec!["1".to_string(), "'Alice'".to_string()]],
"INSERT INTO users (id, name) VALUES (1, 'Alice')".to_string(),
)
.unwrap();
assert_eq!(batcher.batch_size("users"), 1);
}
/// Reaching `max_batch_size` seals the batch, leaving nothing pending.
#[tokio::test]
async fn test_batch_flush_on_size() {
    let config = BatchConfig {
        max_batch_size: 2,
        // A long wait limit ensures only the row limit can trigger the flush,
        // so the first assertion cannot fail because of scheduling delays.
        max_wait_ms: 60_000,
        ..Default::default()
    };
    let batcher = Arc::new(InsertBatcher::new(config));
    batcher
        .add(
            "users".to_string(),
            vec!["id".to_string()],
            vec![vec!["1".to_string()]],
            "INSERT INTO users VALUES (1)".to_string(),
        )
        .unwrap();
    assert_eq!(batcher.batch_size("users"), 1);
    batcher
        .add(
            "users".to_string(),
            vec!["id".to_string()],
            vec![vec!["2".to_string()]],
            "INSERT INTO users VALUES (2)".to_string(),
        )
        .unwrap();
    assert_eq!(batcher.batch_size("users"), 0);
}
/// Two single-row requests are merged into one statement containing both rows.
#[test]
fn test_combine_inserts() {
    // Both requests target the same table and columns; only the row differs.
    let make_request = |id: &str, name: &str| InsertRequest {
        table: "users".to_string(),
        columns: vec!["id".to_string(), "name".to_string()],
        values: vec![vec![id.to_string(), name.to_string()]],
        original_sql: String::new(),
        submitted_at: Instant::now(),
        response_tx: None,
    };
    let requests = vec![make_request("1", "'Alice'"), make_request("2", "'Bob'")];

    let batcher = InsertBatcher::new(BatchConfig::default());
    let combined = batcher.combine_inserts(&requests);

    assert!(combined.contains("INSERT INTO users"));
    assert!(combined.contains("(1, 'Alice')"));
    assert!(combined.contains("(2, 'Bob')"));
}
/// Accepting a request updates the received-insert and received-row counters.
#[test]
fn test_batch_stats() {
    // This test runs without a Tokio runtime. A long wait limit guarantees the
    // insert never becomes due, so `add` cannot reach its task-spawning path.
    let config = BatchConfig {
        max_wait_ms: 60_000,
        ..Default::default()
    };
    let batcher = Arc::new(InsertBatcher::new(config));
    batcher
        .add(
            "users".to_string(),
            vec!["id".to_string()],
            vec![vec!["1".to_string()], vec!["2".to_string()]],
            "INSERT INTO users VALUES (1), (2)".to_string(),
        )
        .unwrap();
    let stats = batcher.stats();
    assert_eq!(stats.inserts_received, 1);
    assert_eq!(stats.rows_received, 2);
}
// End-to-end check against a real database: two rows submitted in one request
// must be present in the table after an explicit flush.
// Skipped unless HELIOS_LIVE_PG is set to a `host:port` address.
#[tokio::test]
async fn flush_executes_against_live_backend() {
use crate::backend::{tls::default_client_config, BackendClient, BackendConfig, TlsMode};
let addr = match std::env::var("HELIOS_LIVE_PG") {
Ok(a) if !a.is_empty() => a,
_ => {
eprintln!("skipping flush_executes_against_live_backend: set HELIOS_LIVE_PG");
return;
}
};
// Split on the last ':' to separate host and port.
let (host, port_s) = addr.rsplit_once(':').unwrap();
let port: u16 = port_s.parse().unwrap();
// Credentials and database name fall back to fixed defaults when unset.
let user = std::env::var("HELIOS_LIVE_USER").unwrap_or_else(|_| "bench".into());
let pass = std::env::var("HELIOS_LIVE_PASS").unwrap_or_else(|_| "benchpass".into());
let db = std::env::var("HELIOS_LIVE_DB").unwrap_or_else(|_| "benchdb".into());
let cfg = BackendConfig {
host: host.to_string(),
port,
user,
password: Some(pass),
database: Some(db),
application_name: Some("helios-batch-test".into()),
tls_mode: TlsMode::Disable,
connect_timeout: Duration::from_secs(5),
query_timeout: Duration::from_secs(5),
tls_config: default_client_config(),
};
// Recreate the probe table so the row count below starts from zero.
let mut seed = BackendClient::connect(&cfg).await.expect("connect seed");
seed.execute("DROP TABLE IF EXISTS batch_probe")
.await
.unwrap();
seed.execute("CREATE TABLE batch_probe(id int, total numeric)")
.await
.unwrap();
seed.close().await;
// Limits are set high enough that only the explicit flush below runs the batch.
let batcher = Arc::new(
InsertBatcher::new(BatchConfig {
max_batch_size: 1000,
max_wait_ms: 60_000,
..Default::default()
})
.with_backend(cfg.clone()),
);
batcher
.add(
"batch_probe".to_string(),
vec!["id".to_string(), "total".to_string()],
vec![
vec!["1".to_string(), "99.99".to_string()],
vec!["2".to_string(), "12.50".to_string()],
],
String::new(),
)
.unwrap();
batcher.flush_batch("batch_probe").await;
// Verify through a separate connection, then drop the table before asserting
// so a failed assertion does not leave it behind.
let mut verify = BackendClient::connect(&cfg).await.expect("connect verify");
let n = verify
.query_scalar("SELECT count(*) AS n FROM batch_probe")
.await
.unwrap()
.as_i64("n")
.unwrap()
.unwrap_or(0);
let _ = verify.execute("DROP TABLE IF EXISTS batch_probe").await;
verify.close().await;
assert_eq!(n, 2, "expected 2 batched rows to land, found {}", n);
assert_eq!(batcher.stats().rows_inserted, 2);
assert_eq!(batcher.stats().flush_failures, 0);
}
}