1use std::collections::VecDeque;
27use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
28use std::time::{Duration, Instant};
29use parking_lot::Mutex;
30use tokio::sync::{broadcast, oneshot};
31
32use super::{Result, ProxyError};
33
/// Tuning knobs for [`SwitchoverBuffer`].
#[derive(Debug, Clone)]
pub struct BufferConfig {
    /// Age limit used in two places: a buffering session older than this rejects new
    /// queries, and a buffered query older than this is reported as timed out at replay.
    pub buffer_timeout: Duration,
    /// Maximum number of queries held in the buffer at once.
    pub max_buffered_queries: usize,
    /// Maximum total bytes (SQL text length plus parameter payload lengths) held at once.
    pub max_buffer_memory: usize,
    /// NOTE(review): not read anywhere in this module — confirm whether a caller
    /// elsewhere consults it.
    pub allow_queries_during_drain: bool,
}

impl Default for BufferConfig {
    /// 5 s timeout, at most 10 000 queries and 100 MiB of buffered data, queries
    /// allowed during drain.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;

        let buffer_timeout = Duration::from_secs(5);
        Self {
            buffer_timeout,
            max_buffered_queries: 10_000,
            max_buffer_memory: 100 * MIB,
            allow_queries_during_drain: true,
        }
    }
}
57
58#[derive(Debug)]
60pub struct BufferedQuery {
61 pub sql: String,
63 pub params: Vec<Vec<u8>>,
65 pub buffered_at: Instant,
67 pub response_tx: oneshot::Sender<BufferResult>,
69 pub client_id: u64,
71}
72
/// Outcome delivered to a client whose query was buffered.
#[derive(Debug)]
pub enum BufferResult {
    /// The query was replayed and the executor reported success.
    Success,
    /// The query was replayed but the executor failed; carries the error's text.
    Error(String),
    /// The query waited longer than `buffer_timeout` and was not replayed.
    Timeout,
    /// The switchover was abandoned and the query was never replayed.
    SwitchoverFailed,
}
85
/// Lifecycle of a [`SwitchoverBuffer`].
///
/// The discriminants are spelled out because the state is stored in an `AtomicU64`
/// and decoded from these exact values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferState {
    /// Queries go straight through; nothing is buffered.
    Passthrough = 0,
    /// New queries are queued for later replay.
    Buffering = 1,
    /// Buffering has stopped; queued queries are waiting to be replayed or failed.
    Draining = 2,
}
96
97pub struct SwitchoverBuffer {
99 config: BufferConfig,
101 state: AtomicU64, is_buffering: AtomicBool,
105 buffer: Mutex<VecDeque<BufferedQuery>>,
107 buffer_memory: AtomicU64,
109 buffering_started: Mutex<Option<Instant>>,
111 stats: BufferStats,
113 state_tx: broadcast::Sender<BufferState>,
115}
116
117impl SwitchoverBuffer {
118 pub fn new(config: BufferConfig) -> Self {
120 let (state_tx, _) = broadcast::channel(16);
121
122 Self {
123 config,
124 state: AtomicU64::new(BufferState::Passthrough as u64),
125 is_buffering: AtomicBool::new(false),
126 buffer: Mutex::new(VecDeque::new()),
127 buffer_memory: AtomicU64::new(0),
128 buffering_started: Mutex::new(None),
129 stats: BufferStats::default(),
130 state_tx,
131 }
132 }
133
134 pub fn is_buffering(&self) -> bool {
136 self.is_buffering.load(Ordering::SeqCst)
137 }
138
139 pub fn state(&self) -> BufferState {
141 match self.state.load(Ordering::SeqCst) {
142 0 => BufferState::Passthrough,
143 1 => BufferState::Buffering,
144 2 => BufferState::Draining,
145 _ => BufferState::Passthrough,
146 }
147 }
148
149 pub fn subscribe(&self) -> broadcast::Receiver<BufferState> {
151 self.state_tx.subscribe()
152 }
153
154 pub fn start_buffering(&self) {
156 self.is_buffering.store(true, Ordering::SeqCst);
157 self.state.store(BufferState::Buffering as u64, Ordering::SeqCst);
158 *self.buffering_started.lock() = Some(Instant::now());
159
160 self.stats.buffering_sessions.fetch_add(1, Ordering::Relaxed);
161
162 let _ = self.state_tx.send(BufferState::Buffering);
163
164 tracing::info!("Switchover buffer: started buffering");
165 }
166
167 pub fn stop_buffering(&self) {
169 self.is_buffering.store(false, Ordering::SeqCst);
170 self.state.store(BufferState::Draining as u64, Ordering::SeqCst);
171
172 let duration = self.buffering_started.lock()
173 .map(|start| start.elapsed())
174 .unwrap_or_default();
175
176 let _ = self.state_tx.send(BufferState::Draining);
177
178 tracing::info!(
179 "Switchover buffer: stopped buffering after {:?}, {} queries buffered",
180 duration,
181 self.buffer.lock().len()
182 );
183 }
184
185 pub fn buffer_query(
187 &self,
188 sql: String,
189 params: Vec<Vec<u8>>,
190 client_id: u64,
191 ) -> Result<oneshot::Receiver<BufferResult>> {
192 if !self.is_buffering() {
194 return Err(ProxyError::Internal("Not in buffering mode".to_string()));
195 }
196
197 if let Some(started) = *self.buffering_started.lock() {
199 if started.elapsed() > self.config.buffer_timeout {
200 return Err(ProxyError::Timeout("Buffer timeout exceeded".to_string()));
201 }
202 }
203
204 let buffer_len = self.buffer.lock().len();
206 if buffer_len >= self.config.max_buffered_queries {
207 self.stats.rejected_queries.fetch_add(1, Ordering::Relaxed);
208 return Err(ProxyError::PoolExhausted("Buffer full".to_string()));
209 }
210
211 let query_size = sql.len() + params.iter().map(|p| p.len()).sum::<usize>();
213 let current_memory = self.buffer_memory.load(Ordering::Relaxed) as usize;
214 if current_memory + query_size > self.config.max_buffer_memory {
215 self.stats.rejected_queries.fetch_add(1, Ordering::Relaxed);
216 return Err(ProxyError::PoolExhausted("Buffer memory exhausted".to_string()));
217 }
218
219 let (response_tx, response_rx) = oneshot::channel();
221
222 let buffered = BufferedQuery {
224 sql,
225 params,
226 buffered_at: Instant::now(),
227 response_tx,
228 client_id,
229 };
230
231 self.buffer.lock().push_back(buffered);
233 self.buffer_memory.fetch_add(query_size as u64, Ordering::Relaxed);
234 self.stats.buffered_queries.fetch_add(1, Ordering::Relaxed);
235
236 Ok(response_rx)
237 }
238
239 pub async fn drain<F, Fut>(&self, execute_fn: F)
244 where
245 F: Fn(String, Vec<Vec<u8>>) -> Fut,
246 Fut: std::future::Future<Output = Result<()>>,
247 {
248 tracing::info!("Switchover buffer: draining buffer");
249
250 let queries: Vec<BufferedQuery> = {
251 let mut buffer = self.buffer.lock();
252 buffer.drain(..).collect()
253 };
254
255 self.buffer_memory.store(0, Ordering::Relaxed);
256
257 let total = queries.len();
258 let mut success = 0;
259 let mut failed = 0;
260 let mut timed_out = 0;
261
262 for query in queries {
263 if query.buffered_at.elapsed() > self.config.buffer_timeout {
265 let _ = query.response_tx.send(BufferResult::Timeout);
266 timed_out += 1;
267 continue;
268 }
269
270 match execute_fn(query.sql, query.params).await {
272 Ok(()) => {
273 let _ = query.response_tx.send(BufferResult::Success);
274 success += 1;
275 }
276 Err(e) => {
277 let _ = query.response_tx.send(BufferResult::Error(e.to_string()));
278 failed += 1;
279 }
280 }
281 }
282
283 self.stats.replayed_queries.fetch_add(success, Ordering::Relaxed);
284 self.stats.failed_replays.fetch_add(failed, Ordering::Relaxed);
285 self.stats.timed_out_queries.fetch_add(timed_out, Ordering::Relaxed);
286
287 self.state.store(BufferState::Passthrough as u64, Ordering::SeqCst);
289 let _ = self.state_tx.send(BufferState::Passthrough);
290
291 tracing::info!(
292 "Switchover buffer: drained {} queries (success: {}, failed: {}, timeout: {})",
293 total,
294 success,
295 failed,
296 timed_out
297 );
298 }
299
300 pub fn fail_all(&self, error: &str) {
302 let queries: Vec<BufferedQuery> = {
303 let mut buffer = self.buffer.lock();
304 buffer.drain(..).collect()
305 };
306
307 let query_count = queries.len();
308 self.buffer_memory.store(0, Ordering::Relaxed);
309
310 for query in queries {
311 let _ = query.response_tx.send(BufferResult::SwitchoverFailed);
312 }
313
314 self.stats.failed_replays.fetch_add(query_count as u64, Ordering::Relaxed);
315
316 self.state.store(BufferState::Passthrough as u64, Ordering::SeqCst);
318 let _ = self.state_tx.send(BufferState::Passthrough);
319
320 tracing::warn!(
321 "Switchover buffer: failed {} queries due to: {}",
322 query_count,
323 error
324 );
325 }
326
327 pub fn len(&self) -> usize {
329 self.buffer.lock().len()
330 }
331
332 pub fn is_empty(&self) -> bool {
334 self.buffer.lock().is_empty()
335 }
336
337 pub fn stats(&self) -> BufferStatsSnapshot {
339 BufferStatsSnapshot {
340 buffering_sessions: self.stats.buffering_sessions.load(Ordering::Relaxed),
341 buffered_queries: self.stats.buffered_queries.load(Ordering::Relaxed),
342 replayed_queries: self.stats.replayed_queries.load(Ordering::Relaxed),
343 failed_replays: self.stats.failed_replays.load(Ordering::Relaxed),
344 timed_out_queries: self.stats.timed_out_queries.load(Ordering::Relaxed),
345 rejected_queries: self.stats.rejected_queries.load(Ordering::Relaxed),
346 current_buffer_size: self.buffer.lock().len(),
347 current_memory_usage: self.buffer_memory.load(Ordering::Relaxed) as usize,
348 }
349 }
350}
351
/// Cumulative counters maintained by [`SwitchoverBuffer`]; all start at zero.
#[derive(Default)]
struct BufferStats {
    /// Number of `start_buffering` calls.
    buffering_sessions: AtomicU64,
    /// Queries accepted into the buffer.
    buffered_queries: AtomicU64,
    /// Queries replayed successfully.
    replayed_queries: AtomicU64,
    /// Queries whose replay errored, plus queries failed via `fail_all`.
    failed_replays: AtomicU64,
    /// Queries skipped at replay because they exceeded `buffer_timeout`.
    timed_out_queries: AtomicU64,
    /// Queries refused because a count or memory limit was reached.
    rejected_queries: AtomicU64,
}
362
/// Plain-value copy of a [`SwitchoverBuffer`]'s counters, returned by `stats()`.
#[derive(Debug, Clone)]
pub struct BufferStatsSnapshot {
    /// Number of buffering sessions started.
    pub buffering_sessions: u64,
    /// Queries accepted into the buffer.
    pub buffered_queries: u64,
    /// Queries replayed successfully.
    pub replayed_queries: u64,
    /// Queries whose replay errored or that were failed via `fail_all`.
    pub failed_replays: u64,
    /// Queries skipped at replay because they exceeded the timeout.
    pub timed_out_queries: u64,
    /// Queries refused because a limit was reached.
    pub rejected_queries: u64,
    /// Queries queued at the moment of the snapshot.
    pub current_buffer_size: usize,
    /// Bytes held by queued queries at the moment of the snapshot.
    pub current_memory_usage: usize,
}
375
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_buffer_state_transitions() {
        let buffer = SwitchoverBuffer::new(BufferConfig::default());

        assert_eq!(buffer.state(), BufferState::Passthrough);
        assert!(!buffer.is_buffering());

        buffer.start_buffering();
        assert_eq!(buffer.state(), BufferState::Buffering);
        assert!(buffer.is_buffering());

        buffer.stop_buffering();
        assert_eq!(buffer.state(), BufferState::Draining);
        assert!(!buffer.is_buffering());
    }

    #[test]
    fn test_state_changes_are_broadcast() {
        let buffer = SwitchoverBuffer::new(BufferConfig::default());
        let mut rx = buffer.subscribe();

        buffer.start_buffering();
        buffer.stop_buffering();

        assert_eq!(rx.try_recv().unwrap(), BufferState::Buffering);
        assert_eq!(rx.try_recv().unwrap(), BufferState::Draining);
    }

    #[tokio::test]
    async fn test_buffer_query() {
        let buffer = SwitchoverBuffer::new(BufferConfig::default());

        // Not buffering yet: the query must be refused.
        let result = buffer.buffer_query("SELECT 1".to_string(), vec![], 1);
        assert!(result.is_err());

        buffer.start_buffering();

        let rx = buffer
            .buffer_query("INSERT INTO t VALUES (1)".to_string(), vec![], 1)
            .unwrap();
        assert_eq!(buffer.len(), 1);

        buffer.drain(|_sql, _params| async { Ok(()) }).await;

        let result = rx.await.unwrap();
        assert!(matches!(result, BufferResult::Success));
        assert!(buffer.is_empty());

        let stats = buffer.stats();
        assert_eq!(stats.buffered_queries, 1);
        assert_eq!(stats.replayed_queries, 1);
    }

    #[test]
    fn test_buffer_limits() {
        let config = BufferConfig {
            max_buffered_queries: 2,
            ..Default::default()
        };
        let buffer = SwitchoverBuffer::new(config);
        buffer.start_buffering();

        let _rx1 = buffer.buffer_query("Q1".to_string(), vec![], 1).unwrap();
        let _rx2 = buffer.buffer_query("Q2".to_string(), vec![], 2).unwrap();

        let result = buffer.buffer_query("Q3".to_string(), vec![], 3);
        assert!(result.is_err());
    }

    #[test]
    fn test_buffer_memory_limit() {
        let config = BufferConfig {
            max_buffer_memory: 4,
            ..Default::default()
        };
        let buffer = SwitchoverBuffer::new(config);
        buffer.start_buffering();

        // 2 bytes fit; 2 more bytes of SQL plus 2 bytes of params would total 6 > 4.
        let _rx = buffer.buffer_query("Q1".to_string(), vec![], 1).unwrap();
        assert!(buffer
            .buffer_query("Q2".to_string(), vec![vec![0, 1]], 2)
            .is_err());

        assert_eq!(buffer.len(), 1);
        let stats = buffer.stats();
        assert_eq!(stats.rejected_queries, 1);
        assert_eq!(stats.current_memory_usage, 2);
    }

    #[tokio::test]
    async fn test_drain_reports_replay_errors() {
        let buffer = SwitchoverBuffer::new(BufferConfig::default());
        buffer.start_buffering();
        let rx = buffer.buffer_query("BAD".to_string(), vec![], 1).unwrap();
        buffer.stop_buffering();

        buffer
            .drain(|_sql, _params| async { Err(ProxyError::Internal("boom".to_string())) })
            .await;

        assert!(matches!(rx.await.unwrap(), BufferResult::Error(_)));
        assert_eq!(buffer.state(), BufferState::Passthrough);

        let stats = buffer.stats();
        assert_eq!(stats.failed_replays, 1);
        assert_eq!(stats.replayed_queries, 0);
        assert_eq!(stats.current_memory_usage, 0);
    }

    #[tokio::test]
    async fn test_fail_all_notifies_waiters() {
        let buffer = SwitchoverBuffer::new(BufferConfig::default());
        buffer.start_buffering();
        let rx = buffer.buffer_query("Q".to_string(), vec![], 1).unwrap();

        buffer.fail_all("primary unreachable");

        assert!(matches!(rx.await.unwrap(), BufferResult::SwitchoverFailed));
        assert!(buffer.is_empty());
        assert_eq!(buffer.state(), BufferState::Passthrough);
        assert_eq!(buffer.stats().failed_replays, 1);
    }
}