1use parking_lot::Mutex;
27use std::collections::VecDeque;
28use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
29use std::time::{Duration, Instant};
30use tokio::sync::{broadcast, oneshot};
31
32use super::{ProxyError, Result};
33
/// Limits and behaviour switches for a [`SwitchoverBuffer`].
#[derive(Debug, Clone)]
pub struct BufferConfig {
    /// Length of a buffering session: once this much time has elapsed since buffering
    /// started, new queries are refused, and queued queries older than this are answered
    /// with `BufferResult::Timeout` instead of being replayed.
    pub buffer_timeout: Duration,
    /// Maximum number of queries that may sit in the queue at the same time.
    pub max_buffered_queries: usize,
    /// Upper bound, in bytes, on the summed SQL text and parameter payloads in the queue.
    pub max_buffer_memory: usize,
    /// NOTE(review): not read by `SwitchoverBuffer` in this file; presumably consulted by
    /// the caller to decide whether queries may run while the buffer drains — confirm.
    pub allow_queries_during_drain: bool,
}

impl Default for BufferConfig {
    /// 5 s timeout, at most 10 000 queued queries, 100 MiB of queued payload, and
    /// `allow_queries_during_drain` enabled.
    fn default() -> Self {
        /// 100 MiB expressed in bytes.
        const DEFAULT_MAX_BUFFER_BYTES: usize = 100 << 20;

        BufferConfig {
            buffer_timeout: Duration::from_millis(5_000),
            max_buffered_queries: 10_000,
            max_buffer_memory: DEFAULT_MAX_BUFFER_BYTES,
            allow_queries_during_drain: true,
        }
    }
}
57
58#[derive(Debug)]
60pub struct BufferedQuery {
61 pub sql: String,
63 pub params: Vec<Vec<u8>>,
65 pub buffered_at: Instant,
67 pub response_tx: oneshot::Sender<BufferResult>,
69 pub client_id: u64,
71}
72
/// Outcome delivered to a client whose query was buffered.
#[derive(Debug)]
pub enum BufferResult {
    /// The query was replayed and the executor reported success.
    Success,
    /// The query was replayed but the executor failed; carries the error's display text.
    Error(String),
    /// The query waited longer than `buffer_timeout` and was not replayed.
    Timeout,
    /// The switchover was abandoned and the query was discarded without replay.
    SwitchoverFailed,
}
85
/// Lifecycle state of a [`SwitchoverBuffer`].
///
/// The discriminants are explicit because the buffer stores its state in an atomic
/// integer (`state as u64`) and decodes it back; reordering or inserting variants must not
/// silently change that encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferState {
    /// Queries bypass the buffer and go straight to the backend.
    Passthrough = 0,
    /// New queries are being queued while the switchover is in progress.
    Buffering = 1,
    /// Buffering has stopped; queued queries are being replayed or failed.
    Draining = 2,
}
96
97pub struct SwitchoverBuffer {
99 config: BufferConfig,
101 state: AtomicU64, is_buffering: AtomicBool,
105 buffer: Mutex<VecDeque<BufferedQuery>>,
107 buffer_memory: AtomicU64,
109 buffering_started: Mutex<Option<Instant>>,
111 stats: BufferStats,
113 state_tx: broadcast::Sender<BufferState>,
115}
116
117impl SwitchoverBuffer {
118 pub fn new(config: BufferConfig) -> Self {
120 let (state_tx, _) = broadcast::channel(16);
121
122 Self {
123 config,
124 state: AtomicU64::new(BufferState::Passthrough as u64),
125 is_buffering: AtomicBool::new(false),
126 buffer: Mutex::new(VecDeque::new()),
127 buffer_memory: AtomicU64::new(0),
128 buffering_started: Mutex::new(None),
129 stats: BufferStats::default(),
130 state_tx,
131 }
132 }
133
134 pub fn is_buffering(&self) -> bool {
136 self.is_buffering.load(Ordering::SeqCst)
137 }
138
139 pub fn state(&self) -> BufferState {
141 match self.state.load(Ordering::SeqCst) {
142 0 => BufferState::Passthrough,
143 1 => BufferState::Buffering,
144 2 => BufferState::Draining,
145 _ => BufferState::Passthrough,
146 }
147 }
148
149 pub fn subscribe(&self) -> broadcast::Receiver<BufferState> {
151 self.state_tx.subscribe()
152 }
153
154 pub fn start_buffering(&self) {
156 self.is_buffering.store(true, Ordering::SeqCst);
157 self.state
158 .store(BufferState::Buffering as u64, Ordering::SeqCst);
159 *self.buffering_started.lock() = Some(Instant::now());
160
161 self.stats
162 .buffering_sessions
163 .fetch_add(1, Ordering::Relaxed);
164
165 let _ = self.state_tx.send(BufferState::Buffering);
166
167 tracing::info!("Switchover buffer: started buffering");
168 }
169
170 pub fn stop_buffering(&self) {
172 self.is_buffering.store(false, Ordering::SeqCst);
173 self.state
174 .store(BufferState::Draining as u64, Ordering::SeqCst);
175
176 let duration = self
177 .buffering_started
178 .lock()
179 .map(|start| start.elapsed())
180 .unwrap_or_default();
181
182 let _ = self.state_tx.send(BufferState::Draining);
183
184 tracing::info!(
185 "Switchover buffer: stopped buffering after {:?}, {} queries buffered",
186 duration,
187 self.buffer.lock().len()
188 );
189 }
190
191 pub fn buffer_query(
193 &self,
194 sql: String,
195 params: Vec<Vec<u8>>,
196 client_id: u64,
197 ) -> Result<oneshot::Receiver<BufferResult>> {
198 if !self.is_buffering() {
200 return Err(ProxyError::Internal("Not in buffering mode".to_string()));
201 }
202
203 if let Some(started) = *self.buffering_started.lock() {
205 if started.elapsed() > self.config.buffer_timeout {
206 return Err(ProxyError::Timeout("Buffer timeout exceeded".to_string()));
207 }
208 }
209
210 let buffer_len = self.buffer.lock().len();
212 if buffer_len >= self.config.max_buffered_queries {
213 self.stats.rejected_queries.fetch_add(1, Ordering::Relaxed);
214 return Err(ProxyError::PoolExhausted("Buffer full".to_string()));
215 }
216
217 let query_size = sql.len() + params.iter().map(|p| p.len()).sum::<usize>();
219 let current_memory = self.buffer_memory.load(Ordering::Relaxed) as usize;
220 if current_memory + query_size > self.config.max_buffer_memory {
221 self.stats.rejected_queries.fetch_add(1, Ordering::Relaxed);
222 return Err(ProxyError::PoolExhausted(
223 "Buffer memory exhausted".to_string(),
224 ));
225 }
226
227 let (response_tx, response_rx) = oneshot::channel();
229
230 let buffered = BufferedQuery {
232 sql,
233 params,
234 buffered_at: Instant::now(),
235 response_tx,
236 client_id,
237 };
238
239 self.buffer.lock().push_back(buffered);
241 self.buffer_memory
242 .fetch_add(query_size as u64, Ordering::Relaxed);
243 self.stats.buffered_queries.fetch_add(1, Ordering::Relaxed);
244
245 Ok(response_rx)
246 }
247
248 pub async fn drain<F, Fut>(&self, execute_fn: F)
253 where
254 F: Fn(String, Vec<Vec<u8>>) -> Fut,
255 Fut: std::future::Future<Output = Result<()>>,
256 {
257 tracing::info!("Switchover buffer: draining buffer");
258
259 let queries: Vec<BufferedQuery> = {
260 let mut buffer = self.buffer.lock();
261 buffer.drain(..).collect()
262 };
263
264 self.buffer_memory.store(0, Ordering::Relaxed);
265
266 let total = queries.len();
267 let mut success = 0;
268 let mut failed = 0;
269 let mut timed_out = 0;
270
271 for query in queries {
272 if query.buffered_at.elapsed() > self.config.buffer_timeout {
274 let _ = query.response_tx.send(BufferResult::Timeout);
275 timed_out += 1;
276 continue;
277 }
278
279 match execute_fn(query.sql, query.params).await {
281 Ok(()) => {
282 let _ = query.response_tx.send(BufferResult::Success);
283 success += 1;
284 }
285 Err(e) => {
286 let _ = query.response_tx.send(BufferResult::Error(e.to_string()));
287 failed += 1;
288 }
289 }
290 }
291
292 self.stats
293 .replayed_queries
294 .fetch_add(success, Ordering::Relaxed);
295 self.stats
296 .failed_replays
297 .fetch_add(failed, Ordering::Relaxed);
298 self.stats
299 .timed_out_queries
300 .fetch_add(timed_out, Ordering::Relaxed);
301
302 self.state
304 .store(BufferState::Passthrough as u64, Ordering::SeqCst);
305 let _ = self.state_tx.send(BufferState::Passthrough);
306
307 tracing::info!(
308 "Switchover buffer: drained {} queries (success: {}, failed: {}, timeout: {})",
309 total,
310 success,
311 failed,
312 timed_out
313 );
314 }
315
316 pub fn fail_all(&self, error: &str) {
318 let queries: Vec<BufferedQuery> = {
319 let mut buffer = self.buffer.lock();
320 buffer.drain(..).collect()
321 };
322
323 let query_count = queries.len();
324 self.buffer_memory.store(0, Ordering::Relaxed);
325
326 for query in queries {
327 let _ = query.response_tx.send(BufferResult::SwitchoverFailed);
328 }
329
330 self.stats
331 .failed_replays
332 .fetch_add(query_count as u64, Ordering::Relaxed);
333
334 self.state
336 .store(BufferState::Passthrough as u64, Ordering::SeqCst);
337 let _ = self.state_tx.send(BufferState::Passthrough);
338
339 tracing::warn!(
340 "Switchover buffer: failed {} queries due to: {}",
341 query_count,
342 error
343 );
344 }
345
346 pub fn len(&self) -> usize {
348 self.buffer.lock().len()
349 }
350
351 pub fn is_empty(&self) -> bool {
353 self.buffer.lock().is_empty()
354 }
355
356 pub fn stats(&self) -> BufferStatsSnapshot {
358 BufferStatsSnapshot {
359 buffering_sessions: self.stats.buffering_sessions.load(Ordering::Relaxed),
360 buffered_queries: self.stats.buffered_queries.load(Ordering::Relaxed),
361 replayed_queries: self.stats.replayed_queries.load(Ordering::Relaxed),
362 failed_replays: self.stats.failed_replays.load(Ordering::Relaxed),
363 timed_out_queries: self.stats.timed_out_queries.load(Ordering::Relaxed),
364 rejected_queries: self.stats.rejected_queries.load(Ordering::Relaxed),
365 current_buffer_size: self.buffer.lock().len(),
366 current_memory_usage: self.buffer_memory.load(Ordering::Relaxed) as usize,
367 }
368 }
369}
370
/// Cumulative, lock-free counters maintained by a `SwitchoverBuffer`.
#[derive(Default)]
struct BufferStats {
    /// Times `start_buffering` has been called.
    buffering_sessions: AtomicU64,
    /// Queries accepted into the queue.
    buffered_queries: AtomicU64,
    /// Queries replayed successfully during a drain.
    replayed_queries: AtomicU64,
    /// Queries whose replay errored, plus queries discarded by `fail_all`.
    failed_replays: AtomicU64,
    /// Queries that expired in the queue before they could be replayed.
    timed_out_queries: AtomicU64,
    /// Queries refused because the count or memory limit was reached.
    rejected_queries: AtomicU64,
}
381
/// Plain-value copy of a buffer's counters, as returned by `SwitchoverBuffer::stats`.
#[derive(Debug, Clone)]
pub struct BufferStatsSnapshot {
    /// Number of buffering sessions started.
    pub buffering_sessions: u64,
    /// Queries accepted into the queue.
    pub buffered_queries: u64,
    /// Queries replayed successfully.
    pub replayed_queries: u64,
    /// Replays that errored, plus queries discarded by `fail_all`.
    pub failed_replays: u64,
    /// Queries that expired before replay.
    pub timed_out_queries: u64,
    /// Queries refused because a count or memory limit was reached.
    pub rejected_queries: u64,
    /// Queries in the queue when the snapshot was taken.
    pub current_buffer_size: usize,
    /// Bytes of SQL and parameters in the queue when the snapshot was taken.
    pub current_memory_usage: usize,
}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks Passthrough -> Buffering -> Draining and checks each transition is broadcast.
    #[test]
    fn test_buffer_state_transitions() {
        let buffer = SwitchoverBuffer::new(BufferConfig::default());
        let mut events = buffer.subscribe();

        assert_eq!(buffer.state(), BufferState::Passthrough);
        assert!(!buffer.is_buffering());

        buffer.start_buffering();
        assert_eq!(buffer.state(), BufferState::Buffering);
        assert!(buffer.is_buffering());

        buffer.stop_buffering();
        assert_eq!(buffer.state(), BufferState::Draining);
        assert!(!buffer.is_buffering());

        assert_eq!(events.try_recv().unwrap(), BufferState::Buffering);
        assert_eq!(events.try_recv().unwrap(), BufferState::Draining);
    }

    /// A buffered query is replayed by `drain` and its client is told it succeeded.
    #[tokio::test]
    async fn test_buffer_query() {
        let buffer = SwitchoverBuffer::new(BufferConfig::default());

        // Not buffering yet: the query must be refused.
        let result = buffer.buffer_query("SELECT 1".to_string(), vec![], 1);
        assert!(result.is_err());

        buffer.start_buffering();

        let rx = buffer
            .buffer_query("INSERT INTO t VALUES (1)".to_string(), vec![], 1)
            .unwrap();
        assert_eq!(buffer.len(), 1);

        buffer.drain(|_sql, _params| async { Ok(()) }).await;

        let result = rx.await.unwrap();
        assert!(matches!(result, BufferResult::Success));
        assert!(buffer.is_empty());
        assert_eq!(buffer.state(), BufferState::Passthrough);

        let stats = buffer.stats();
        assert_eq!(stats.buffered_queries, 1);
        assert_eq!(stats.replayed_queries, 1);
        assert_eq!(stats.current_memory_usage, 0);
    }

    /// The queue refuses queries beyond `max_buffered_queries`.
    #[test]
    fn test_buffer_limits() {
        let config = BufferConfig {
            max_buffered_queries: 2,
            ..Default::default()
        };
        let buffer = SwitchoverBuffer::new(config);
        buffer.start_buffering();

        let _ = buffer.buffer_query("Q1".to_string(), vec![], 1).unwrap();
        let _ = buffer.buffer_query("Q2".to_string(), vec![], 2).unwrap();

        let result = buffer.buffer_query("Q3".to_string(), vec![], 3);
        assert!(result.is_err());
        assert_eq!(buffer.stats().rejected_queries, 1);
    }

    /// A query that fits exactly is accepted; the next byte over the limit is refused.
    #[test]
    fn test_memory_limit() {
        let config = BufferConfig {
            max_buffer_memory: 10,
            ..Default::default()
        };
        let buffer = SwitchoverBuffer::new(config);
        buffer.start_buffering();

        // 8 bytes of SQL + 2 bytes of params == the 10-byte limit.
        let _rx = buffer
            .buffer_query("SELECT 1".to_string(), vec![vec![0u8; 2]], 1)
            .unwrap();
        assert!(buffer.buffer_query("x".to_string(), vec![], 2).is_err());

        let stats = buffer.stats();
        assert_eq!(stats.rejected_queries, 1);
        assert_eq!(stats.current_memory_usage, 10);
        assert_eq!(stats.current_buffer_size, 1);
    }

    /// Once buffering has stopped, new queries are refused.
    #[test]
    fn test_rejects_queries_after_stop() {
        let buffer = SwitchoverBuffer::new(BufferConfig::default());
        buffer.start_buffering();
        buffer.stop_buffering();

        assert!(buffer
            .buffer_query("SELECT 1".to_string(), vec![], 1)
            .is_err());
        assert!(buffer.is_empty());
    }

    /// `fail_all` answers every queued client with `SwitchoverFailed` and resets usage.
    #[tokio::test]
    async fn test_fail_all() {
        let buffer = SwitchoverBuffer::new(BufferConfig::default());
        buffer.start_buffering();

        let rx = buffer.buffer_query("Q".to_string(), vec![], 1).unwrap();
        buffer.fail_all("primary unreachable");

        assert!(matches!(rx.await.unwrap(), BufferResult::SwitchoverFailed));
        assert!(buffer.is_empty());
        assert_eq!(buffer.state(), BufferState::Passthrough);

        let stats = buffer.stats();
        assert_eq!(stats.failed_replays, 1);
        assert_eq!(stats.current_memory_usage, 0);
    }

    /// A failing replay is reported to the client as `Error` and counted as failed.
    #[tokio::test]
    async fn test_drain_reports_failures() {
        let buffer = SwitchoverBuffer::new(BufferConfig::default());
        buffer.start_buffering();

        let rx = buffer.buffer_query("Q".to_string(), vec![], 1).unwrap();
        buffer.stop_buffering();

        buffer
            .drain(|_sql, _params| async { Err(ProxyError::Internal("boom".to_string())) })
            .await;

        assert!(matches!(rx.await.unwrap(), BufferResult::Error(_)));
        assert_eq!(buffer.state(), BufferState::Passthrough);

        let stats = buffer.stats();
        assert_eq!(stats.failed_replays, 1);
        assert_eq!(stats.replayed_queries, 0);
    }
}