Skip to main content

heliosdb_proxy/
switchover_buffer.rs

//! Switchover Buffer - Query buffering during controlled switchover
//!
//! Buffers write queries during the brief switchover window to ensure
//! zero transaction loss. Queries are replayed to the new primary
//! once switchover completes.
//!
//! ## How it works
//!
//! ```text
//! Normal Operation:
//!   Client → Proxy → Primary
//!
//! During Switchover:
//!   Client → Proxy → Buffer (queued)
//!                      ↓
//!   [Switchover completes]
//!                      ↓
//!            Buffer → New Primary (replayed)
//! ```
//!
//! ## Timeout Behavior
//!
//! If switchover takes longer than `buffer_timeout`, new queries are
//! rejected with a timeout error, and queries that have already waited
//! longer than `buffer_timeout` receive a timeout result when the buffer
//! is drained, rather than blocking indefinitely.
25
26use std::collections::VecDeque;
27use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
28use std::time::{Duration, Instant};
29use parking_lot::Mutex;
30use tokio::sync::{broadcast, oneshot};
31
32use super::{Result, ProxyError};
33
/// Tunable limits for the switchover buffer.
#[derive(Debug, Clone)]
pub struct BufferConfig {
    /// Longest time queries may be held in the buffer (default: 5s).
    pub buffer_timeout: Duration,
    /// Upper bound on the number of queued queries (default: 10000).
    pub max_buffered_queries: usize,
    /// Upper bound, in bytes, on the memory used by queued queries (default: 100MB).
    pub max_buffer_memory: usize,
    /// Whether new queries are allowed while the buffer is draining (default: true).
    pub allow_queries_during_drain: bool,
}

impl Default for BufferConfig {
    /// Builds the configuration with the defaults documented on each field.
    fn default() -> Self {
        const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
        const DEFAULT_MAX_QUERIES: usize = 10_000;
        const MIB: usize = 1024 * 1024;

        BufferConfig {
            buffer_timeout: DEFAULT_TIMEOUT,
            max_buffered_queries: DEFAULT_MAX_QUERIES,
            max_buffer_memory: 100 * MIB,
            allow_queries_during_drain: true,
        }
    }
}
57
58/// A buffered query waiting to be executed
59#[derive(Debug)]
60pub struct BufferedQuery {
61    /// SQL statement
62    pub sql: String,
63    /// Query parameters
64    pub params: Vec<Vec<u8>>,
65    /// Time when query was buffered
66    pub buffered_at: Instant,
67    /// Channel to send result back to client
68    pub response_tx: oneshot::Sender<BufferResult>,
69    /// Client identifier (for logging/debugging)
70    pub client_id: u64,
71}
72
/// Outcome delivered to a client for a query that went through the buffer.
#[derive(Debug)]
pub enum BufferResult {
    /// The query was replayed and executed successfully.
    Success,
    /// The query was replayed but execution failed; carries the error text.
    Error(String),
    /// The query stayed in the buffer longer than the configured timeout.
    Timeout,
    /// The switchover was cancelled or failed, so the query was never replayed.
    SwitchoverFailed,
}
85
/// Operating mode of the switchover buffer.
///
/// The numeric values are what gets stored when a state is cast with `as u64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferState {
    /// Normal operation - queries are not buffered.
    Passthrough = 0,
    /// A switchover is in progress and writes are being queued.
    Buffering = 1,
    /// Queued queries are being replayed to the new primary.
    Draining = 2,
}
96
97/// Switchover buffer for zero-downtime primary transitions
98pub struct SwitchoverBuffer {
99    /// Configuration
100    config: BufferConfig,
101    /// Current state
102    state: AtomicU64, // BufferState as u64
103    /// Is buffering active
104    is_buffering: AtomicBool,
105    /// Buffered queries
106    buffer: Mutex<VecDeque<BufferedQuery>>,
107    /// Current buffer memory usage
108    buffer_memory: AtomicU64,
109    /// Time when buffering started
110    buffering_started: Mutex<Option<Instant>>,
111    /// Statistics
112    stats: BufferStats,
113    /// State change broadcaster
114    state_tx: broadcast::Sender<BufferState>,
115}
116
117impl SwitchoverBuffer {
118    /// Create a new switchover buffer
119    pub fn new(config: BufferConfig) -> Self {
120        let (state_tx, _) = broadcast::channel(16);
121
122        Self {
123            config,
124            state: AtomicU64::new(BufferState::Passthrough as u64),
125            is_buffering: AtomicBool::new(false),
126            buffer: Mutex::new(VecDeque::new()),
127            buffer_memory: AtomicU64::new(0),
128            buffering_started: Mutex::new(None),
129            stats: BufferStats::default(),
130            state_tx,
131        }
132    }
133
134    /// Check if currently buffering
135    pub fn is_buffering(&self) -> bool {
136        self.is_buffering.load(Ordering::SeqCst)
137    }
138
139    /// Get current state
140    pub fn state(&self) -> BufferState {
141        match self.state.load(Ordering::SeqCst) {
142            0 => BufferState::Passthrough,
143            1 => BufferState::Buffering,
144            2 => BufferState::Draining,
145            _ => BufferState::Passthrough,
146        }
147    }
148
149    /// Subscribe to state changes
150    pub fn subscribe(&self) -> broadcast::Receiver<BufferState> {
151        self.state_tx.subscribe()
152    }
153
154    /// Start buffering (called when switchover begins)
155    pub fn start_buffering(&self) {
156        self.is_buffering.store(true, Ordering::SeqCst);
157        self.state.store(BufferState::Buffering as u64, Ordering::SeqCst);
158        *self.buffering_started.lock() = Some(Instant::now());
159
160        self.stats.buffering_sessions.fetch_add(1, Ordering::Relaxed);
161
162        let _ = self.state_tx.send(BufferState::Buffering);
163
164        tracing::info!("Switchover buffer: started buffering");
165    }
166
167    /// Stop buffering (called when switchover completes or fails)
168    pub fn stop_buffering(&self) {
169        self.is_buffering.store(false, Ordering::SeqCst);
170        self.state.store(BufferState::Draining as u64, Ordering::SeqCst);
171
172        let duration = self.buffering_started.lock()
173            .map(|start| start.elapsed())
174            .unwrap_or_default();
175
176        let _ = self.state_tx.send(BufferState::Draining);
177
178        tracing::info!(
179            "Switchover buffer: stopped buffering after {:?}, {} queries buffered",
180            duration,
181            self.buffer.lock().len()
182        );
183    }
184
185    /// Buffer a query (returns receiver for result)
186    pub fn buffer_query(
187        &self,
188        sql: String,
189        params: Vec<Vec<u8>>,
190        client_id: u64,
191    ) -> Result<oneshot::Receiver<BufferResult>> {
192        // Check if we should buffer
193        if !self.is_buffering() {
194            return Err(ProxyError::Internal("Not in buffering mode".to_string()));
195        }
196
197        // Check timeout
198        if let Some(started) = *self.buffering_started.lock() {
199            if started.elapsed() > self.config.buffer_timeout {
200                return Err(ProxyError::Timeout("Buffer timeout exceeded".to_string()));
201            }
202        }
203
204        // Check capacity
205        let buffer_len = self.buffer.lock().len();
206        if buffer_len >= self.config.max_buffered_queries {
207            self.stats.rejected_queries.fetch_add(1, Ordering::Relaxed);
208            return Err(ProxyError::PoolExhausted("Buffer full".to_string()));
209        }
210
211        // Check memory
212        let query_size = sql.len() + params.iter().map(|p| p.len()).sum::<usize>();
213        let current_memory = self.buffer_memory.load(Ordering::Relaxed) as usize;
214        if current_memory + query_size > self.config.max_buffer_memory {
215            self.stats.rejected_queries.fetch_add(1, Ordering::Relaxed);
216            return Err(ProxyError::PoolExhausted("Buffer memory exhausted".to_string()));
217        }
218
219        // Create response channel
220        let (response_tx, response_rx) = oneshot::channel();
221
222        // Create buffered query
223        let buffered = BufferedQuery {
224            sql,
225            params,
226            buffered_at: Instant::now(),
227            response_tx,
228            client_id,
229        };
230
231        // Add to buffer
232        self.buffer.lock().push_back(buffered);
233        self.buffer_memory.fetch_add(query_size as u64, Ordering::Relaxed);
234        self.stats.buffered_queries.fetch_add(1, Ordering::Relaxed);
235
236        Ok(response_rx)
237    }
238
239    /// Drain buffer and replay queries to new primary
240    ///
241    /// The `execute_fn` is called for each buffered query to execute it
242    /// against the new primary.
243    pub async fn drain<F, Fut>(&self, execute_fn: F)
244    where
245        F: Fn(String, Vec<Vec<u8>>) -> Fut,
246        Fut: std::future::Future<Output = Result<()>>,
247    {
248        tracing::info!("Switchover buffer: draining buffer");
249
250        let queries: Vec<BufferedQuery> = {
251            let mut buffer = self.buffer.lock();
252            buffer.drain(..).collect()
253        };
254
255        self.buffer_memory.store(0, Ordering::Relaxed);
256
257        let total = queries.len();
258        let mut success = 0;
259        let mut failed = 0;
260        let mut timed_out = 0;
261
262        for query in queries {
263            // Check if query timed out while buffered
264            if query.buffered_at.elapsed() > self.config.buffer_timeout {
265                let _ = query.response_tx.send(BufferResult::Timeout);
266                timed_out += 1;
267                continue;
268            }
269
270            // Execute query
271            match execute_fn(query.sql, query.params).await {
272                Ok(()) => {
273                    let _ = query.response_tx.send(BufferResult::Success);
274                    success += 1;
275                }
276                Err(e) => {
277                    let _ = query.response_tx.send(BufferResult::Error(e.to_string()));
278                    failed += 1;
279                }
280            }
281        }
282
283        self.stats.replayed_queries.fetch_add(success, Ordering::Relaxed);
284        self.stats.failed_replays.fetch_add(failed, Ordering::Relaxed);
285        self.stats.timed_out_queries.fetch_add(timed_out, Ordering::Relaxed);
286
287        // Return to passthrough mode
288        self.state.store(BufferState::Passthrough as u64, Ordering::SeqCst);
289        let _ = self.state_tx.send(BufferState::Passthrough);
290
291        tracing::info!(
292            "Switchover buffer: drained {} queries (success: {}, failed: {}, timeout: {})",
293            total,
294            success,
295            failed,
296            timed_out
297        );
298    }
299
300    /// Fail all buffered queries (called if switchover fails)
301    pub fn fail_all(&self, error: &str) {
302        let queries: Vec<BufferedQuery> = {
303            let mut buffer = self.buffer.lock();
304            buffer.drain(..).collect()
305        };
306
307        let query_count = queries.len();
308        self.buffer_memory.store(0, Ordering::Relaxed);
309
310        for query in queries {
311            let _ = query.response_tx.send(BufferResult::SwitchoverFailed);
312        }
313
314        self.stats.failed_replays.fetch_add(query_count as u64, Ordering::Relaxed);
315
316        // Return to passthrough mode
317        self.state.store(BufferState::Passthrough as u64, Ordering::SeqCst);
318        let _ = self.state_tx.send(BufferState::Passthrough);
319
320        tracing::warn!(
321            "Switchover buffer: failed {} queries due to: {}",
322            query_count,
323            error
324        );
325    }
326
327    /// Get current buffer length
328    pub fn len(&self) -> usize {
329        self.buffer.lock().len()
330    }
331
332    /// Check if buffer is empty
333    pub fn is_empty(&self) -> bool {
334        self.buffer.lock().is_empty()
335    }
336
337    /// Get buffer statistics
338    pub fn stats(&self) -> BufferStatsSnapshot {
339        BufferStatsSnapshot {
340            buffering_sessions: self.stats.buffering_sessions.load(Ordering::Relaxed),
341            buffered_queries: self.stats.buffered_queries.load(Ordering::Relaxed),
342            replayed_queries: self.stats.replayed_queries.load(Ordering::Relaxed),
343            failed_replays: self.stats.failed_replays.load(Ordering::Relaxed),
344            timed_out_queries: self.stats.timed_out_queries.load(Ordering::Relaxed),
345            rejected_queries: self.stats.rejected_queries.load(Ordering::Relaxed),
346            current_buffer_size: self.buffer.lock().len(),
347            current_memory_usage: self.buffer_memory.load(Ordering::Relaxed) as usize,
348        }
349    }
350}
351
/// Lifetime counters kept by the buffer; each one is an independent atomic.
#[derive(Default)]
struct BufferStats {
    /// Number of times buffering was started.
    buffering_sessions: AtomicU64,
    /// Number of queries accepted into the buffer.
    buffered_queries: AtomicU64,
    /// Number of buffered queries replayed successfully.
    replayed_queries: AtomicU64,
    /// Number of buffered queries that failed on replay or were failed in bulk.
    failed_replays: AtomicU64,
    /// Number of buffered queries that exceeded the timeout before replay.
    timed_out_queries: AtomicU64,
    /// Number of queries refused because a buffer limit was reached.
    rejected_queries: AtomicU64,
}
362
/// Plain-value copy of the buffer's counters and current occupancy.
#[derive(Debug, Clone)]
pub struct BufferStatsSnapshot {
    /// Number of times buffering was started.
    pub buffering_sessions: u64,
    /// Number of queries accepted into the buffer.
    pub buffered_queries: u64,
    /// Number of buffered queries replayed successfully.
    pub replayed_queries: u64,
    /// Number of buffered queries that failed on replay or were failed in bulk.
    pub failed_replays: u64,
    /// Number of buffered queries that exceeded the timeout before replay.
    pub timed_out_queries: u64,
    /// Number of queries refused because a buffer limit was reached.
    pub rejected_queries: u64,
    /// Number of queries in the buffer when the snapshot was taken.
    pub current_buffer_size: usize,
    /// Bytes accounted to buffered queries when the snapshot was taken.
    pub current_memory_usage: usize,
}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks Passthrough -> Buffering -> Draining, checking state and flag at each step.
    #[test]
    fn test_buffer_state_transitions() {
        let buf = SwitchoverBuffer::new(BufferConfig::default());

        assert!(!buf.is_buffering());
        assert_eq!(BufferState::Passthrough, buf.state());

        buf.start_buffering();
        assert!(buf.is_buffering());
        assert_eq!(BufferState::Buffering, buf.state());

        buf.stop_buffering();
        assert!(!buf.is_buffering());
        assert_eq!(BufferState::Draining, buf.state());
    }

    /// A query is refused outside buffering mode, accepted inside it, and
    /// resolves to `Success` once the buffer is drained.
    #[tokio::test]
    async fn test_buffer_query() {
        let buf = SwitchoverBuffer::new(BufferConfig::default());

        // Not buffering yet, so the query must be refused.
        let refused = buf.buffer_query(String::from("SELECT 1"), Vec::new(), 1);
        assert!(refused.is_err());

        buf.start_buffering();

        // Buffering is active, so the query is queued.
        let rx = buf
            .buffer_query(String::from("INSERT INTO t VALUES (1)"), Vec::new(), 1)
            .expect("query should be accepted while buffering");
        assert_eq!(1, buf.len());

        // Replay with an executor that always succeeds.
        buf.drain(|_sql, _params| async { Ok(()) }).await;

        let outcome = rx.await.expect("drain should answer the query");
        assert!(matches!(outcome, BufferResult::Success));
        assert!(buf.is_empty());
    }

    /// With a limit of two queries, the third submission is rejected.
    #[test]
    fn test_buffer_limits() {
        let buf = SwitchoverBuffer::new(BufferConfig {
            max_buffered_queries: 2,
            ..Default::default()
        });
        buf.start_buffering();

        // Fill the buffer up to its limit; the receivers are not needed.
        for &(client_id, sql) in &[(1u64, "Q1"), (2u64, "Q2")] {
            let rx = buf
                .buffer_query(sql.to_string(), Vec::new(), client_id)
                .expect("query within the limit should be accepted");
            drop(rx);
        }

        // One more than the limit must be refused.
        let overflow = buf.buffer_query("Q3".to_string(), Vec::new(), 3);
        assert!(overflow.is_err());
    }
}