Skip to main content

fr_rust/ws/
batcher.rs

1use serde::{Deserialize, Serialize};
2use tokio::sync::Mutex;
3use tokio::time::{interval, Duration};
4use tokio_postgres::{NoTls, Error as PgError};
5use deadpool_postgres::{Pool, Config, ManagerConfig, RecyclingMethod, Runtime};
6use std::sync::Arc;
7use std::collections::VecDeque;
8use thiserror::Error;
9use futures_util::SinkExt; // Required for copy_writer.send()
10
11#[derive(Error, Debug)]
12pub enum BatcherError {
13    #[error("PostgreSQL error: {0}")]
14    Pg(#[from] PgError),
15    #[error("Pool error: {0}")]
16    Pool(#[from] deadpool_postgres::PoolError),
17    #[error("Pool creation error: {0}")]
18    CreatePool(#[from] deadpool_postgres::CreatePoolError),
19    #[error("Serialization error: {0}")]
20    Serialization(#[from] bincode::Error),
21    #[error("IO error: {0}")]
22    Io(#[from] std::io::Error),
23}
24
25type Result<T> = std::result::Result<T, BatcherError>;
26
27#[derive(Serialize, Deserialize, Debug, Clone)]
28pub struct Message {
29    pub time: u64,
30    pub id: String,
31    pub content: String,
32}
33
34pub struct MsgBatcher {
35    pool: Pool,
36    buffer: Arc<Mutex<VecDeque<Message>>>,
37    batch_size: usize,
38    flush_interval: Duration,
39    max_buffer_size: usize,
40    running: Arc<Mutex<bool>>,
41}
42
43impl MsgBatcher {
44    /// Create new batcher with PostgreSQL connection string
45    pub async fn new(database_url: &str) -> Result<Self> {
46        let mut cfg = Config::new();
47        cfg.url = Some(database_url.to_string());
48        
49        // Correctly assign recycling_method to ManagerConfig
50        cfg.manager = Some(ManagerConfig {
51            recycling_method: RecyclingMethod::Fast,
52        });
53
54        // Use standard initialization to avoid missing QueueMode types
55        cfg.pool = Some(deadpool_postgres::PoolConfig {
56            max_size: 16,
57            ..Default::default()
58        });
59        
60        let pool = cfg.create_pool(Some(Runtime::Tokio1), NoTls)?;
61        
62        // Initialize table
63        let client = pool.get().await?;
64        client.execute(
65            "CREATE TABLE IF NOT EXISTS messages (
66                id BIGSERIAL PRIMARY KEY,
67                time BIGINT NOT NULL,
68                user_id TEXT NOT NULL,
69                content TEXT NOT NULL,
70                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
71            )",
72            &[],
73        ).await?;
74        
75        // Check if table is UNLOGGED for max performance
76        client.execute(
77            "ALTER TABLE IF EXISTS messages SET UNLOGGED",
78            &[],
79        ).await?;
80        
81        Ok(Self {
82            pool,
83            buffer: Arc::new(Mutex::new(VecDeque::with_capacity(10000))),
84            batch_size: 5000,
85            flush_interval: Duration::from_secs(5),
86            max_buffer_size: 10000,
87            running: Arc::new(Mutex::new(true)),
88        })
89    }
90
91    /// Configure batch size (default: 5000)
92    pub fn with_batch_size(mut self, size: usize) -> Self {
93        self.batch_size = size;
94        self
95    }
96
97    /// Configure flush interval in seconds (default: 5)
98    pub fn with_flush_interval(mut self, seconds: u64) -> Self {
99        self.flush_interval = Duration::from_secs(seconds);
100        self
101    }
102
103    /// Configure max buffer size (default: 10000)
104    pub fn with_max_buffer(mut self, size: usize) -> Self {
105        self.max_buffer_size = size;
106        self
107    }
108
109    /// Append message to buffer (non-blocking, ~microseconds)
110    pub async fn append(&self, msg: Message) -> Result<()> {
111        let mut buffer = self.buffer.lock().await;
112        buffer.push_back(msg);
113        
114        let len = buffer.len();
115        
116        // Emergency flush if buffer is too large
117        if len >= self.max_buffer_size {
118            drop(buffer);
119            self.flush().await?;
120        } else if len >= self.batch_size {
121            let batch: Vec<Message> = buffer.drain(..len).collect();
122            drop(buffer);
123            self.flush_batch(batch).await?;
124        }
125        
126        Ok(())
127    }
128
129    /// Manually flush all pending messages
130    pub async fn flush(&self) -> Result<()> {
131        let mut buffer = self.buffer.lock().await;
132        if buffer.is_empty() {
133            return Ok(());
134        }
135        
136        let batch: Vec<Message> = buffer.drain(..).collect();
137        drop(buffer);
138        
139        self.flush_batch(batch).await
140    }
141
142    /// Background worker - call this in your main function
143    pub async fn run_background(&self) -> Result<()> {
144        let buffer = Arc::clone(&self.buffer);
145        let pool = self.pool.clone();
146        let batch_size = self.batch_size;
147        let flush_interval = self.flush_interval;
148        let running = Arc::clone(&self.running);
149        let mut interval = interval(flush_interval);
150
151        tokio::spawn(async move {
152            loop {
153                interval.tick().await;
154                
155                // Check if we should stop
156                let should_stop = !*running.lock().await;
157                if should_stop {
158                    break;
159                }
160                
161                let mut guard = buffer.lock().await;
162                if guard.is_empty() {
163                    continue;
164                }
165                
166                // Drain in chunks for efficiency
167                let batches: Vec<Vec<Message>> = guard
168                    .drain(..)
169                    .collect::<Vec<Message>>()
170                    .chunks(batch_size)
171                    .map(|chunk| chunk.to_vec())
172                    .collect();
173                drop(guard);
174                
175                // Process each chunk
176                for batch in batches {
177                    if let Err(e) = Self::bulk_insert(&pool, batch).await {
178                        eprintln!("Failed to flush batch: {}", e);
179                    }
180                }
181            }
182        });
183        
184        Ok(())
185    }
186
187    /// Stop background worker gracefully
188    pub async fn shutdown(&self) -> Result<()> {
189        let mut running = self.running.lock().await;
190        *running = false;
191        drop(running);
192        
193        // Final flush
194        self.flush().await?;
195        Ok(())
196    }
197
198    /// Fastest bulk insert using COPY
199    async fn bulk_insert(pool: &Pool, messages: Vec<Message>) -> Result<()> {
200        if messages.is_empty() {
201            return Ok(());
202        }
203        
204        let client = pool.get().await?;
205        
206        // Use COPY for maximum performance
207        let copy_stmt = "COPY messages (time, user_id, content) FROM STDIN (FORMAT CSV, DELIMITER ',')";
208        let copy_writer = client.copy_in(copy_stmt).await?;
209        
210        // Pin the writer immediately so SinkExt methods can safely be called on it
211        tokio::pin!(copy_writer);
212        
213        // Pre-allocate buffer for performance
214        let mut batch_buffer = String::with_capacity(messages.len() * 256);
215        
216        for msg in &messages {
217            batch_buffer.push_str(&msg.time.to_string());
218            batch_buffer.push(',');
219            batch_buffer.push_str(&msg.id);
220            batch_buffer.push_str(",\"");
221            
222            // Allocation-free CSV escaping
223            for c in msg.content.chars() {
224                if c == '"' {
225                    batch_buffer.push_str("\"\"");
226                } else {
227                    batch_buffer.push(c);
228                }
229            }
230            batch_buffer.push_str("\"\n");
231        }
232        
233        // Now `.send()` works flawlessly because `copy_writer` is pinned to the stack
234        copy_writer.as_mut().send(bytes::Bytes::from(batch_buffer)).await?;
235        copy_writer.finish().await?;
236        
237        Ok(())
238    }
239
240    async fn flush_batch(&self, messages: Vec<Message>) -> Result<()> {
241        if messages.is_empty() {
242            return Ok(());
243        }
244        Self::bulk_insert(&self.pool, messages).await
245    }
246
247    /// Get current buffer size
248    pub async fn buffer_size(&self) -> usize {
249        self.buffer.lock().await.len()
250    }
251}