use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use tokio::time::{interval, Duration};
use tokio_postgres::{NoTls, Error as PgError};
use deadpool_postgres::{Pool, Config, ManagerConfig, RecyclingMethod, Runtime};
use std::sync::Arc;
use std::collections::VecDeque;
use thiserror::Error;
use futures_util::SinkExt;
/// Error type shared by the fallible operations in this file.
///
/// Every variant wraps an underlying error and carries a `#[from]`
/// conversion, so the `?` operator can turn the source error into a
/// `BatcherError` automatically.
#[derive(Error, Debug)]
pub enum BatcherError {
/// A `tokio_postgres` error (statement execution, `COPY`, ...).
#[error("PostgreSQL error: {0}")]
Pg(#[from] PgError),
/// Failure while obtaining a connection from the deadpool pool.
#[error("Pool error: {0}")]
Pool(#[from] deadpool_postgres::PoolError),
/// Failure while building the deadpool pool from its configuration.
#[error("Pool creation error: {0}")]
CreatePool(#[from] deadpool_postgres::CreatePoolError),
/// A `bincode` (de)serialization error.
/// NOTE(review): no code visible in this part of the file produces this
/// variant — confirm it is still needed.
#[error("Serialization error: {0}")]
Serialization(#[from] bincode::Error),
/// A standard-library I/O error.
/// NOTE(review): no code visible in this part of the file produces this
/// variant — confirm it is still needed.
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
}
/// File-local shorthand for `std::result::Result` with the error type fixed
/// to `BatcherError`.
type Result<T> = std::result::Result<T, BatcherError>;
/// One message held in the batcher's buffer until it is written to the
/// `messages` table.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
/// Value written to the `time` column (`BIGINT`).
/// NOTE(review): unit and epoch are not visible here — confirm with the
/// producers of this value.
pub time: u64,
/// Value written to the `user_id` column. Despite the field name this is
/// not the table's `id` primary key, which is a `BIGSERIAL` generated by
/// the database.
pub id: String,
/// Message body, written to the `content` column.
pub content: String,
}
/// Buffers `Message`s in memory and writes them to PostgreSQL in bulk.
///
/// Messages are added with `append`; they are written when a size threshold
/// is reached, when `flush` is called, or on the timer started by
/// `run_background`.
pub struct MsgBatcher {
// Connection pool used for every bulk insert.
pool: Pool,
// Pending messages; shared (via `Arc`) with the background flush task.
buffer: Arc<Mutex<VecDeque<Message>>>,
// Buffer length at which `append` flushes; also the maximum number of
// messages the background task puts in one insert.
batch_size: usize,
// Period of the background flush timer.
flush_interval: Duration,
// Buffer length at which `append` forces a flush of everything buffered.
max_buffer_size: usize,
// `true` while the background task should keep running; cleared by
// `shutdown` and checked by the task after every timer tick.
running: Arc<Mutex<bool>>,
}
impl MsgBatcher {
/// Connects to `database_url`, prepares the `messages` table and returns a
/// batcher with default settings.
///
/// The pool is limited to 16 connections and uses the `Fast` recycling
/// method. On every call the table is created if it does not exist yet and
/// then switched to `UNLOGGED`.
///
/// Defaults: batch size 5000, flush interval 5 seconds, maximum buffer
/// size 10000 (also the initial buffer capacity).
///
/// NOTE(review): PostgreSQL documents unlogged tables as not crash-safe —
/// confirm that losing the contents of `messages` after a server crash is
/// acceptable.
///
/// # Errors
/// Returns an error if the pool cannot be created, no connection can be
/// obtained, or either setup statement fails.
pub async fn new(database_url: &str) -> Result<Self> {
    const DEFAULT_BATCH_SIZE: usize = 5000;
    const DEFAULT_MAX_BUFFER: usize = 10000;
    const CREATE_TABLE_SQL: &str = "CREATE TABLE IF NOT EXISTS messages (\n\
        id BIGSERIAL PRIMARY KEY,\n\
        time BIGINT NOT NULL,\n\
        user_id TEXT NOT NULL,\n\
        content TEXT NOT NULL,\n\
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP\n\
        )";
    const SET_UNLOGGED_SQL: &str = "ALTER TABLE IF EXISTS messages SET UNLOGGED";

    let manager = ManagerConfig {
        recycling_method: RecyclingMethod::Fast,
    };
    let pool_limits = deadpool_postgres::PoolConfig {
        max_size: 16,
        ..Default::default()
    };

    let mut settings = Config::new();
    settings.url = Some(database_url.to_string());
    settings.manager = Some(manager);
    settings.pool = Some(pool_limits);
    let pool = settings.create_pool(Some(Runtime::Tokio1), NoTls)?;

    // Run the schema setup on one pooled connection, in order.
    let conn = pool.get().await?;
    for statement in [CREATE_TABLE_SQL, SET_UNLOGGED_SQL] {
        conn.execute(statement, &[]).await?;
    }

    Ok(Self {
        pool,
        buffer: Arc::new(Mutex::new(VecDeque::with_capacity(DEFAULT_MAX_BUFFER))),
        batch_size: DEFAULT_BATCH_SIZE,
        flush_interval: Duration::from_secs(5),
        max_buffer_size: DEFAULT_MAX_BUFFER,
        running: Arc::new(Mutex::new(true)),
    })
}
/// Sets the number of buffered messages that triggers a flush and returns
/// the updated batcher (builder style).
///
/// A `size` of 0 is raised to 1: pending messages cannot be split into
/// batches of zero length, and attempting to do so would panic or make no
/// progress in the background flush.
pub fn with_batch_size(mut self, size: usize) -> Self {
    self.batch_size = size.max(1);
    self
}
/// Sets the period of the background flush timer, in whole seconds, and
/// returns the updated batcher (builder style).
///
/// A value of 0 is raised to 1 second: `run_background` passes this period
/// to `tokio::time::interval`, which panics when given a zero period.
pub fn with_flush_interval(mut self, seconds: u64) -> Self {
    self.flush_interval = Duration::from_secs(seconds.max(1));
    self
}
/// Sets the buffer length at which `append` forces a flush and returns the
/// updated batcher (builder style).
pub fn with_max_buffer(self, size: usize) -> Self {
    let mut batcher = self;
    batcher.max_buffer_size = size;
    batcher
}
/// Queues `msg` and, once the buffer has reached either `batch_size` or
/// `max_buffer_size`, writes everything buffered to the database.
///
/// The push and the drain happen under a single lock acquisition, so the
/// buffer cannot grow past the threshold between the size check and the
/// drain.
///
/// # Errors
/// Returns the error of the bulk insert if a flush was triggered and
/// failed. The drained messages are not put back into the buffer in that
/// case.
pub async fn append(&self, msg: Message) -> Result<()> {
    // Both limits trigger the same action (flush everything), so only the
    // smaller one matters.
    let threshold = self.batch_size.min(self.max_buffer_size);

    let batch: Vec<Message> = {
        let mut buffer = self.buffer.lock().await;
        buffer.push_back(msg);
        if buffer.len() < threshold {
            return Ok(());
        }
        buffer.drain(..).collect()
    };

    // The lock is released here, so the insert does not block other appenders.
    self.flush_batch(batch).await
}
/// Writes every currently buffered message to the database.
///
/// Does nothing and returns `Ok(())` when the buffer is empty. The buffer
/// lock is released before the insert starts.
///
/// # Errors
/// Returns the error of the bulk insert; the drained messages are not put
/// back into the buffer in that case.
pub async fn flush(&self) -> Result<()> {
    let pending: Vec<Message> = {
        let mut queue = self.buffer.lock().await;
        queue.drain(..).collect()
    };
    if pending.is_empty() {
        Ok(())
    } else {
        self.flush_batch(pending).await
    }
}
/// Spawns a detached Tokio task that periodically writes the buffered
/// messages to the database.
///
/// After every tick of a `flush_interval` timer the task checks the
/// `running` flag and exits if it has been cleared. Otherwise it drains the
/// buffer and inserts the messages in batches of at most `batch_size`.
/// Insert failures are printed to stderr and the affected batch is dropped;
/// the task keeps running.
///
/// This function itself always returns `Ok(())`; the task's `JoinHandle` is
/// discarded.
///
/// # Panics
/// `tokio::time::interval` panics if `flush_interval` is zero.
pub async fn run_background(&self) -> Result<()> {
    let buffer = Arc::clone(&self.buffer);
    let pool = self.pool.clone();
    let running = Arc::clone(&self.running);
    // A batch length of 0 could never make progress, so treat it as 1.
    let batch_size = self.batch_size.max(1);
    let mut ticker = interval(self.flush_interval);

    tokio::spawn(async move {
        loop {
            ticker.tick().await;

            let keep_running = *running.lock().await;
            if !keep_running {
                break;
            }

            // Hold the buffer lock only while draining, not during inserts.
            let pending: Vec<Message> = {
                let mut guard = buffer.lock().await;
                guard.drain(..).collect()
            };

            // Move the messages into batches instead of cloning each chunk.
            let mut remaining = pending.into_iter();
            loop {
                let batch: Vec<Message> = remaining.by_ref().take(batch_size).collect();
                if batch.is_empty() {
                    break;
                }
                if let Err(e) = Self::bulk_insert(&pool, batch).await {
                    eprintln!("Failed to flush batch: {}", e);
                }
            }
        }
    });

    Ok(())
}
/// Signals the background task to stop and writes out whatever is still
/// buffered.
///
/// The `running` flag is cleared first; the background task observes it
/// after its next timer tick.
///
/// # Errors
/// Returns the error of the final flush.
pub async fn shutdown(&self) -> Result<()> {
    // The guard is a temporary, so the lock is released at the end of this
    // statement, before the flush starts.
    *self.running.lock().await = false;
    self.flush().await
}
/// Writes `messages` to the `messages` table with one `COPY ... FROM STDIN`.
///
/// Each message becomes one CSV row of the form `time,"id","content"`.
/// Both text fields are always quoted and embedded `"` characters are
/// doubled, so commas, quotes and line breaks inside a field cannot corrupt
/// the row or inject additional rows.
///
/// Returns `Ok(())` without touching the pool when `messages` is empty.
///
/// # Errors
/// Returns an error if no connection can be obtained from the pool or if
/// starting, streaming, or finishing the `COPY` fails.
async fn bulk_insert(pool: &Pool, messages: Vec<Message>) -> Result<()> {
    /// Appends `field` to `out` as a quoted CSV field, doubling embedded quotes.
    fn push_quoted(out: &mut String, field: &str) {
        out.push('"');
        for ch in field.chars() {
            if ch == '"' {
                out.push('"');
            }
            out.push(ch);
        }
        out.push('"');
    }

    if messages.is_empty() {
        return Ok(());
    }

    let client = pool.get().await?;
    let copy_stmt = "COPY messages (time, user_id, content) FROM STDIN (FORMAT CSV, DELIMITER ',')";
    let copy_writer = client.copy_in(copy_stmt).await?;
    tokio::pin!(copy_writer);

    let mut rows = String::with_capacity(messages.len() * 256);
    for msg in &messages {
        rows.push_str(&msg.time.to_string());
        rows.push(',');
        // Quote the id as well: PostgreSQL's CSV mode reads an unquoted
        // empty field as NULL, which `user_id TEXT NOT NULL` rejects, and
        // an unquoted id containing `,`, `"` or a newline would break the row.
        push_quoted(&mut rows, &msg.id);
        rows.push(',');
        push_quoted(&mut rows, &msg.content);
        rows.push('\n');
    }

    copy_writer.as_mut().send(bytes::Bytes::from(rows)).await?;
    copy_writer.finish().await?;
    Ok(())
}
/// Inserts `messages` through this batcher's connection pool.
///
/// An empty `messages` vector is a no-op that returns `Ok(())`.
///
/// # Errors
/// Returns the error of the bulk insert.
async fn flush_batch(&self, messages: Vec<Message>) -> Result<()> {
    if !messages.is_empty() {
        Self::bulk_insert(&self.pool, messages).await
    } else {
        Ok(())
    }
}
/// Returns the number of messages currently waiting in the buffer.
///
/// Waits for the buffer lock, so the value is a snapshot that may be stale
/// as soon as the lock is released.
pub async fn buffer_size(&self) -> usize {
    let queue = self.buffer.lock().await;
    queue.len()
}
}