1use serde::{Deserialize, Serialize};
2use tokio::sync::Mutex;
3use tokio::time::{interval, Duration};
4use tokio_postgres::{NoTls, Error as PgError};
5use deadpool_postgres::{Pool, Config, ManagerConfig, RecyclingMethod, Runtime};
6use std::sync::Arc;
7use std::collections::VecDeque;
8use thiserror::Error;
9use futures_util::SinkExt; #[derive(Error, Debug)]
12pub enum BatcherError {
13 #[error("PostgreSQL error: {0}")]
14 Pg(#[from] PgError),
15 #[error("Pool error: {0}")]
16 Pool(#[from] deadpool_postgres::PoolError),
17 #[error("Pool creation error: {0}")]
18 CreatePool(#[from] deadpool_postgres::CreatePoolError),
19 #[error("Serialization error: {0}")]
20 Serialization(#[from] bincode::Error),
21 #[error("IO error: {0}")]
22 Io(#[from] std::io::Error),
23}
24
25type Result<T> = std::result::Result<T, BatcherError>;
26
27#[derive(Serialize, Deserialize, Debug, Clone)]
28pub struct Message {
29 pub time: u64,
30 pub id: String,
31 pub content: String,
32}
33
34pub struct MsgBatcher {
35 pool: Pool,
36 buffer: Arc<Mutex<VecDeque<Message>>>,
37 batch_size: usize,
38 flush_interval: Duration,
39 max_buffer_size: usize,
40 running: Arc<Mutex<bool>>,
41}
42
43impl MsgBatcher {
44 pub async fn new(database_url: &str) -> Result<Self> {
46 let mut cfg = Config::new();
47 cfg.url = Some(database_url.to_string());
48
49 cfg.manager = Some(ManagerConfig {
51 recycling_method: RecyclingMethod::Fast,
52 });
53
54 cfg.pool = Some(deadpool_postgres::PoolConfig {
56 max_size: 16,
57 ..Default::default()
58 });
59
60 let pool = cfg.create_pool(Some(Runtime::Tokio1), NoTls)?;
61
62 let client = pool.get().await?;
64 client.execute(
65 "CREATE TABLE IF NOT EXISTS messages (
66 id BIGSERIAL PRIMARY KEY,
67 time BIGINT NOT NULL,
68 user_id TEXT NOT NULL,
69 content TEXT NOT NULL,
70 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
71 )",
72 &[],
73 ).await?;
74
75 client.execute(
77 "ALTER TABLE IF EXISTS messages SET UNLOGGED",
78 &[],
79 ).await?;
80
81 Ok(Self {
82 pool,
83 buffer: Arc::new(Mutex::new(VecDeque::with_capacity(10000))),
84 batch_size: 5000,
85 flush_interval: Duration::from_secs(5),
86 max_buffer_size: 10000,
87 running: Arc::new(Mutex::new(true)),
88 })
89 }
90
91 pub fn with_batch_size(mut self, size: usize) -> Self {
93 self.batch_size = size;
94 self
95 }
96
97 pub fn with_flush_interval(mut self, seconds: u64) -> Self {
99 self.flush_interval = Duration::from_secs(seconds);
100 self
101 }
102
103 pub fn with_max_buffer(mut self, size: usize) -> Self {
105 self.max_buffer_size = size;
106 self
107 }
108
109 pub async fn append(&self, msg: Message) -> Result<()> {
111 let mut buffer = self.buffer.lock().await;
112 buffer.push_back(msg);
113
114 let len = buffer.len();
115
116 if len >= self.max_buffer_size {
118 drop(buffer);
119 self.flush().await?;
120 } else if len >= self.batch_size {
121 let batch: Vec<Message> = buffer.drain(..len).collect();
122 drop(buffer);
123 self.flush_batch(batch).await?;
124 }
125
126 Ok(())
127 }
128
129 pub async fn flush(&self) -> Result<()> {
131 let mut buffer = self.buffer.lock().await;
132 if buffer.is_empty() {
133 return Ok(());
134 }
135
136 let batch: Vec<Message> = buffer.drain(..).collect();
137 drop(buffer);
138
139 self.flush_batch(batch).await
140 }
141
142 pub async fn run_background(&self) -> Result<()> {
144 let buffer = Arc::clone(&self.buffer);
145 let pool = self.pool.clone();
146 let batch_size = self.batch_size;
147 let flush_interval = self.flush_interval;
148 let running = Arc::clone(&self.running);
149 let mut interval = interval(flush_interval);
150
151 tokio::spawn(async move {
152 loop {
153 interval.tick().await;
154
155 let should_stop = !*running.lock().await;
157 if should_stop {
158 break;
159 }
160
161 let mut guard = buffer.lock().await;
162 if guard.is_empty() {
163 continue;
164 }
165
166 let batches: Vec<Vec<Message>> = guard
168 .drain(..)
169 .collect::<Vec<Message>>()
170 .chunks(batch_size)
171 .map(|chunk| chunk.to_vec())
172 .collect();
173 drop(guard);
174
175 for batch in batches {
177 if let Err(e) = Self::bulk_insert(&pool, batch).await {
178 eprintln!("Failed to flush batch: {}", e);
179 }
180 }
181 }
182 });
183
184 Ok(())
185 }
186
187 pub async fn shutdown(&self) -> Result<()> {
189 let mut running = self.running.lock().await;
190 *running = false;
191 drop(running);
192
193 self.flush().await?;
195 Ok(())
196 }
197
198 async fn bulk_insert(pool: &Pool, messages: Vec<Message>) -> Result<()> {
200 if messages.is_empty() {
201 return Ok(());
202 }
203
204 let client = pool.get().await?;
205
206 let copy_stmt = "COPY messages (time, user_id, content) FROM STDIN (FORMAT CSV, DELIMITER ',')";
208 let copy_writer = client.copy_in(copy_stmt).await?;
209
210 tokio::pin!(copy_writer);
212
213 let mut batch_buffer = String::with_capacity(messages.len() * 256);
215
216 for msg in &messages {
217 batch_buffer.push_str(&msg.time.to_string());
218 batch_buffer.push(',');
219 batch_buffer.push_str(&msg.id);
220 batch_buffer.push_str(",\"");
221
222 for c in msg.content.chars() {
224 if c == '"' {
225 batch_buffer.push_str("\"\"");
226 } else {
227 batch_buffer.push(c);
228 }
229 }
230 batch_buffer.push_str("\"\n");
231 }
232
233 copy_writer.as_mut().send(bytes::Bytes::from(batch_buffer)).await?;
235 copy_writer.finish().await?;
236
237 Ok(())
238 }
239
240 async fn flush_batch(&self, messages: Vec<Message>) -> Result<()> {
241 if messages.is_empty() {
242 return Ok(());
243 }
244 Self::bulk_insert(&self.pool, messages).await
245 }
246
247 pub async fn buffer_size(&self) -> usize {
249 self.buffer.lock().await.len()
250 }
251}