use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use parking_lot::Mutex;
use tokio::sync::{broadcast, oneshot};
use super::{Result, ProxyError};
/// Limits that govern how a [`SwitchoverBuffer`] holds queries during a switchover.
#[derive(Debug, Clone)]
pub struct BufferConfig {
    /// Maximum age of a buffering session / buffered query.
    ///
    /// `buffer_query` rejects new queries once the current buffering session has been
    /// open longer than this, and `drain` answers queries that waited longer than this
    /// with [`BufferResult::Timeout`] instead of replaying them.
    pub buffer_timeout: Duration,
    /// Maximum number of queries that may sit in the buffer at once.
    pub max_buffered_queries: usize,
    /// Upper bound, in bytes, on the combined size of SQL text and parameter payloads
    /// held in the buffer.
    pub max_buffer_memory: usize,
    /// NOTE(review): not read anywhere in this file; presumably consulted by the caller
    /// to decide whether new queries may run while the buffer drains — confirm.
    pub allow_queries_during_drain: bool,
}
impl Default for BufferConfig {
    /// Default limits: 5 second timeout, 10 000 queries, 100 MiB of payload, and
    /// `allow_queries_during_drain` enabled.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;

        Self {
            buffer_timeout: Duration::new(5, 0),
            max_buffered_queries: 10_000,
            max_buffer_memory: 100 * MIB,
            allow_queries_during_drain: true,
        }
    }
}
/// One query parked in the buffer, together with the channel used to report its outcome.
#[derive(Debug)]
pub struct BufferedQuery {
    /// SQL text handed to the replay function when the buffer is drained.
    pub sql: String,
    /// Parameter values as raw byte vectors, passed to the replay function unchanged.
    pub params: Vec<Vec<u8>>,
    /// Moment the query was placed in the buffer; compared against the configured
    /// timeout when the buffer is drained.
    pub buffered_at: Instant,
    /// One-shot channel on which the final [`BufferResult`] is delivered.
    pub response_tx: oneshot::Sender<BufferResult>,
    /// Caller-supplied identifier of the originating client; stored but not read by
    /// the code in this file.
    pub client_id: u64,
}
/// Outcome delivered to the party waiting on a buffered query.
#[derive(Debug)]
pub enum BufferResult {
    /// The replay function returned `Ok(())` for the query.
    Success,
    /// The replay function returned an error; the payload is that error's text.
    Error(String),
    /// The query sat in the buffer longer than the configured timeout and was not replayed.
    Timeout,
    /// The buffer was flushed through `fail_all`; the query was never replayed.
    SwitchoverFailed,
}
/// Lifecycle phase of the switchover buffer.
///
/// The numeric values are spelled out because the phase is stored as a `u64` in an
/// atomic and decoded again by `SwitchoverBuffer::state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferState {
    /// No switchover in progress.
    Passthrough = 0,
    /// A switchover is in progress and queries are being held.
    Buffering = 1,
    /// Buffering has stopped; held queries are waiting to be replayed.
    Draining = 2,
}
/// Holds client queries while a switchover is in progress so they can be replayed
/// (or failed) once it finishes.
pub struct SwitchoverBuffer {
    /// Limits applied when accepting and draining queries.
    config: BufferConfig,
    /// Current [`BufferState`], stored as its `u64` discriminant.
    state: AtomicU64,
    /// Flag consulted by `buffer_query` to decide whether new queries are accepted.
    is_buffering: AtomicBool,
    /// Queries waiting to be replayed, oldest first.
    buffer: Mutex<VecDeque<BufferedQuery>>,
    /// Running total, in bytes, of SQL text plus parameter payloads held in `buffer`.
    buffer_memory: AtomicU64,
    /// When the most recent buffering session was started, if any.
    buffering_started: Mutex<Option<Instant>>,
    /// Monotonic counters exposed through `stats`.
    stats: BufferStats,
    /// Broadcasts every state transition to receivers obtained from `subscribe`.
    state_tx: broadcast::Sender<BufferState>,
}
impl SwitchoverBuffer {
pub fn new(config: BufferConfig) -> Self {
let (state_tx, _) = broadcast::channel(16);
Self {
config,
state: AtomicU64::new(BufferState::Passthrough as u64),
is_buffering: AtomicBool::new(false),
buffer: Mutex::new(VecDeque::new()),
buffer_memory: AtomicU64::new(0),
buffering_started: Mutex::new(None),
stats: BufferStats::default(),
state_tx,
}
}
pub fn is_buffering(&self) -> bool {
self.is_buffering.load(Ordering::SeqCst)
}
pub fn state(&self) -> BufferState {
match self.state.load(Ordering::SeqCst) {
0 => BufferState::Passthrough,
1 => BufferState::Buffering,
2 => BufferState::Draining,
_ => BufferState::Passthrough,
}
}
pub fn subscribe(&self) -> broadcast::Receiver<BufferState> {
self.state_tx.subscribe()
}
pub fn start_buffering(&self) {
self.is_buffering.store(true, Ordering::SeqCst);
self.state.store(BufferState::Buffering as u64, Ordering::SeqCst);
*self.buffering_started.lock() = Some(Instant::now());
self.stats.buffering_sessions.fetch_add(1, Ordering::Relaxed);
let _ = self.state_tx.send(BufferState::Buffering);
tracing::info!("Switchover buffer: started buffering");
}
pub fn stop_buffering(&self) {
self.is_buffering.store(false, Ordering::SeqCst);
self.state.store(BufferState::Draining as u64, Ordering::SeqCst);
let duration = self.buffering_started.lock()
.map(|start| start.elapsed())
.unwrap_or_default();
let _ = self.state_tx.send(BufferState::Draining);
tracing::info!(
"Switchover buffer: stopped buffering after {:?}, {} queries buffered",
duration,
self.buffer.lock().len()
);
}
pub fn buffer_query(
&self,
sql: String,
params: Vec<Vec<u8>>,
client_id: u64,
) -> Result<oneshot::Receiver<BufferResult>> {
if !self.is_buffering() {
return Err(ProxyError::Internal("Not in buffering mode".to_string()));
}
if let Some(started) = *self.buffering_started.lock() {
if started.elapsed() > self.config.buffer_timeout {
return Err(ProxyError::Timeout("Buffer timeout exceeded".to_string()));
}
}
let buffer_len = self.buffer.lock().len();
if buffer_len >= self.config.max_buffered_queries {
self.stats.rejected_queries.fetch_add(1, Ordering::Relaxed);
return Err(ProxyError::PoolExhausted("Buffer full".to_string()));
}
let query_size = sql.len() + params.iter().map(|p| p.len()).sum::<usize>();
let current_memory = self.buffer_memory.load(Ordering::Relaxed) as usize;
if current_memory + query_size > self.config.max_buffer_memory {
self.stats.rejected_queries.fetch_add(1, Ordering::Relaxed);
return Err(ProxyError::PoolExhausted("Buffer memory exhausted".to_string()));
}
let (response_tx, response_rx) = oneshot::channel();
let buffered = BufferedQuery {
sql,
params,
buffered_at: Instant::now(),
response_tx,
client_id,
};
self.buffer.lock().push_back(buffered);
self.buffer_memory.fetch_add(query_size as u64, Ordering::Relaxed);
self.stats.buffered_queries.fetch_add(1, Ordering::Relaxed);
Ok(response_rx)
}
pub async fn drain<F, Fut>(&self, execute_fn: F)
where
F: Fn(String, Vec<Vec<u8>>) -> Fut,
Fut: std::future::Future<Output = Result<()>>,
{
tracing::info!("Switchover buffer: draining buffer");
let queries: Vec<BufferedQuery> = {
let mut buffer = self.buffer.lock();
buffer.drain(..).collect()
};
self.buffer_memory.store(0, Ordering::Relaxed);
let total = queries.len();
let mut success = 0;
let mut failed = 0;
let mut timed_out = 0;
for query in queries {
if query.buffered_at.elapsed() > self.config.buffer_timeout {
let _ = query.response_tx.send(BufferResult::Timeout);
timed_out += 1;
continue;
}
match execute_fn(query.sql, query.params).await {
Ok(()) => {
let _ = query.response_tx.send(BufferResult::Success);
success += 1;
}
Err(e) => {
let _ = query.response_tx.send(BufferResult::Error(e.to_string()));
failed += 1;
}
}
}
self.stats.replayed_queries.fetch_add(success, Ordering::Relaxed);
self.stats.failed_replays.fetch_add(failed, Ordering::Relaxed);
self.stats.timed_out_queries.fetch_add(timed_out, Ordering::Relaxed);
self.state.store(BufferState::Passthrough as u64, Ordering::SeqCst);
let _ = self.state_tx.send(BufferState::Passthrough);
tracing::info!(
"Switchover buffer: drained {} queries (success: {}, failed: {}, timeout: {})",
total,
success,
failed,
timed_out
);
}
pub fn fail_all(&self, error: &str) {
let queries: Vec<BufferedQuery> = {
let mut buffer = self.buffer.lock();
buffer.drain(..).collect()
};
let query_count = queries.len();
self.buffer_memory.store(0, Ordering::Relaxed);
for query in queries {
let _ = query.response_tx.send(BufferResult::SwitchoverFailed);
}
self.stats.failed_replays.fetch_add(query_count as u64, Ordering::Relaxed);
self.state.store(BufferState::Passthrough as u64, Ordering::SeqCst);
let _ = self.state_tx.send(BufferState::Passthrough);
tracing::warn!(
"Switchover buffer: failed {} queries due to: {}",
query_count,
error
);
}
pub fn len(&self) -> usize {
self.buffer.lock().len()
}
pub fn is_empty(&self) -> bool {
self.buffer.lock().is_empty()
}
pub fn stats(&self) -> BufferStatsSnapshot {
BufferStatsSnapshot {
buffering_sessions: self.stats.buffering_sessions.load(Ordering::Relaxed),
buffered_queries: self.stats.buffered_queries.load(Ordering::Relaxed),
replayed_queries: self.stats.replayed_queries.load(Ordering::Relaxed),
failed_replays: self.stats.failed_replays.load(Ordering::Relaxed),
timed_out_queries: self.stats.timed_out_queries.load(Ordering::Relaxed),
rejected_queries: self.stats.rejected_queries.load(Ordering::Relaxed),
current_buffer_size: self.buffer.lock().len(),
current_memory_usage: self.buffer_memory.load(Ordering::Relaxed) as usize,
}
}
}
/// Monotonic counters maintained by `SwitchoverBuffer`; all start at zero.
#[derive(Default)]
struct BufferStats {
    /// Number of times buffering was started.
    buffering_sessions: AtomicU64,
    /// Number of queries accepted into the buffer.
    buffered_queries: AtomicU64,
    /// Number of buffered queries replayed successfully.
    replayed_queries: AtomicU64,
    /// Number of buffered queries whose replay failed or that were failed in bulk.
    failed_replays: AtomicU64,
    /// Number of buffered queries that exceeded the timeout before being replayed.
    timed_out_queries: AtomicU64,
    /// Number of queries refused because the buffer was full or out of memory budget.
    rejected_queries: AtomicU64,
}
/// Plain-value copy of the buffer's counters, taken at one point in time.
#[derive(Debug, Clone)]
pub struct BufferStatsSnapshot {
    /// Number of times buffering was started.
    pub buffering_sessions: u64,
    /// Number of queries accepted into the buffer.
    pub buffered_queries: u64,
    /// Number of buffered queries replayed successfully.
    pub replayed_queries: u64,
    /// Number of buffered queries whose replay failed or that were failed in bulk.
    pub failed_replays: u64,
    /// Number of buffered queries that exceeded the timeout before being replayed.
    pub timed_out_queries: u64,
    /// Number of queries refused because the buffer was full or out of memory budget.
    pub rejected_queries: u64,
    /// Number of queries in the buffer when the snapshot was taken.
    pub current_buffer_size: usize,
    /// Value of the buffer's byte counter when the snapshot was taken.
    pub current_memory_usage: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Passthrough -> Buffering -> Draining, with the buffering flag tracking each step.
    #[test]
    fn test_buffer_state_transitions() {
        let switchover = SwitchoverBuffer::new(BufferConfig::default());
        assert!(!switchover.is_buffering());
        assert_eq!(switchover.state(), BufferState::Passthrough);

        switchover.start_buffering();
        assert!(switchover.is_buffering());
        assert_eq!(switchover.state(), BufferState::Buffering);

        switchover.stop_buffering();
        assert!(!switchover.is_buffering());
        assert_eq!(switchover.state(), BufferState::Draining);
    }

    /// A query is refused before buffering starts, accepted afterwards, and reported
    /// as successful once drained.
    #[tokio::test]
    async fn test_buffer_query() {
        let switchover = SwitchoverBuffer::new(BufferConfig::default());

        let refused = switchover.buffer_query("SELECT 1".to_string(), Vec::new(), 1);
        assert!(refused.is_err());

        switchover.start_buffering();
        let receiver = switchover
            .buffer_query("INSERT INTO t VALUES (1)".to_string(), Vec::new(), 1)
            .unwrap();
        assert_eq!(switchover.len(), 1);

        switchover.drain(|_sql, _params| async { Ok(()) }).await;

        let outcome = receiver.await.unwrap();
        assert!(matches!(outcome, BufferResult::Success));
        assert!(switchover.is_empty());
    }

    /// With a two-query limit, the third query is rejected.
    #[test]
    fn test_buffer_limits() {
        let switchover = SwitchoverBuffer::new(BufferConfig {
            max_buffered_queries: 2,
            ..Default::default()
        });
        switchover.start_buffering();

        for &(sql, client) in &[("Q1", 1u64), ("Q2", 2u64)] {
            let _ = switchover
                .buffer_query(sql.to_string(), Vec::new(), client)
                .unwrap();
        }

        let overflow = switchover.buffer_query("Q3".to_string(), Vec::new(), 3);
        assert!(overflow.is_err());
    }
}