Skip to main content

heliosdb_proxy/
switchover_buffer.rs

1//! Switchover Buffer - Query buffering during controlled switchover
2//!
3//! Buffers write queries during the brief switchover window to ensure
4//! zero transaction loss. Queries are replayed to the new primary
5//! once switchover completes.
6//!
7//! ## How it works
8//!
9//! ```text
10//! Normal Operation:
11//!   Client → Proxy → Primary
12//!
13//! During Switchover:
14//!   Client → Proxy → Buffer (queued)
15//!                      ↓
16//!   [Switchover completes]
17//!                      ↓
18//!            Buffer → New Primary (replayed)
19//! ```
20//!
21//! ## Timeout Behavior
22//!
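//! Once a switchover has been buffering for longer than `buffer_timeout`,
//! new queries are rejected with a timeout error instead of being queued.
//! Queries already in the buffer are answered only when `drain` or `fail_all`
//! runs; at drain time, any query that waited longer than `buffer_timeout`
//! receives `BufferResult::Timeout` instead of being replayed.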
25
26use parking_lot::Mutex;
27use std::collections::VecDeque;
28use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
29use std::time::{Duration, Instant};
30use tokio::sync::{broadcast, oneshot};
31
32use super::{ProxyError, Result};
33
/// Tuning knobs for [`SwitchoverBuffer`].
#[derive(Clone, Debug)]
pub struct BufferConfig {
    /// Longest a switchover may keep buffering before new queries are refused,
    /// and the longest a single query may wait before it is answered with a
    /// timeout instead of being replayed (default: 5s).
    pub buffer_timeout: Duration,
    /// Upper bound on the number of queued queries (default: 10 000).
    pub max_buffered_queries: usize,
    /// Upper bound, in bytes, on the summed SQL text and parameter payloads of
    /// the queued queries (default: 100 MiB). Allocation overhead is not counted.
    pub max_buffer_memory: usize,
    /// Whether new queries may be admitted while the buffer is draining.
    /// NOTE(review): not read by `SwitchoverBuffer` itself — confirm how callers use it.
    pub allow_queries_during_drain: bool,
}

impl Default for BufferConfig {
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;

        BufferConfig {
            buffer_timeout: Duration::from_secs(5),
            max_buffered_queries: 10_000,
            max_buffer_memory: 100 * MIB,
            allow_queries_during_drain: true,
        }
    }
}
57
58/// A buffered query waiting to be executed
59#[derive(Debug)]
60pub struct BufferedQuery {
61    /// SQL statement
62    pub sql: String,
63    /// Query parameters
64    pub params: Vec<Vec<u8>>,
65    /// Time when query was buffered
66    pub buffered_at: Instant,
67    /// Channel to send result back to client
68    pub response_tx: oneshot::Sender<BufferResult>,
69    /// Client identifier (for logging/debugging)
70    pub client_id: u64,
71}
72
/// Outcome delivered to a client whose query was buffered.
#[derive(Debug)]
pub enum BufferResult {
    /// The replay against the new primary succeeded.
    Success,
    /// The replay failed; carries the error's string form.
    Error(String),
    /// The query waited in the buffer longer than `buffer_timeout` and was not replayed.
    Timeout,
    /// The switchover was abandoned and the query was never replayed.
    SwitchoverFailed,
}
85
/// Lifecycle phase of a [`SwitchoverBuffer`].
///
/// The discriminants are spelled out because the buffer keeps the phase in an
/// `AtomicU64` and decodes it by these numeric values.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BufferState {
    /// Normal operation: queries are not buffered.
    Passthrough = 0,
    /// A switchover is underway and writes are queued.
    Buffering = 1,
    /// Queued queries are being replayed against the new primary.
    Draining = 2,
}
96
97/// Switchover buffer for zero-downtime primary transitions
98pub struct SwitchoverBuffer {
99    /// Configuration
100    config: BufferConfig,
101    /// Current state
102    state: AtomicU64, // BufferState as u64
103    /// Is buffering active
104    is_buffering: AtomicBool,
105    /// Buffered queries
106    buffer: Mutex<VecDeque<BufferedQuery>>,
107    /// Current buffer memory usage
108    buffer_memory: AtomicU64,
109    /// Time when buffering started
110    buffering_started: Mutex<Option<Instant>>,
111    /// Statistics
112    stats: BufferStats,
113    /// State change broadcaster
114    state_tx: broadcast::Sender<BufferState>,
115}
116
117impl SwitchoverBuffer {
118    /// Create a new switchover buffer
119    pub fn new(config: BufferConfig) -> Self {
120        let (state_tx, _) = broadcast::channel(16);
121
122        Self {
123            config,
124            state: AtomicU64::new(BufferState::Passthrough as u64),
125            is_buffering: AtomicBool::new(false),
126            buffer: Mutex::new(VecDeque::new()),
127            buffer_memory: AtomicU64::new(0),
128            buffering_started: Mutex::new(None),
129            stats: BufferStats::default(),
130            state_tx,
131        }
132    }
133
134    /// Check if currently buffering
135    pub fn is_buffering(&self) -> bool {
136        self.is_buffering.load(Ordering::SeqCst)
137    }
138
139    /// Get current state
140    pub fn state(&self) -> BufferState {
141        match self.state.load(Ordering::SeqCst) {
142            0 => BufferState::Passthrough,
143            1 => BufferState::Buffering,
144            2 => BufferState::Draining,
145            _ => BufferState::Passthrough,
146        }
147    }
148
149    /// Subscribe to state changes
150    pub fn subscribe(&self) -> broadcast::Receiver<BufferState> {
151        self.state_tx.subscribe()
152    }
153
154    /// Start buffering (called when switchover begins)
155    pub fn start_buffering(&self) {
156        self.is_buffering.store(true, Ordering::SeqCst);
157        self.state
158            .store(BufferState::Buffering as u64, Ordering::SeqCst);
159        *self.buffering_started.lock() = Some(Instant::now());
160
161        self.stats
162            .buffering_sessions
163            .fetch_add(1, Ordering::Relaxed);
164
165        let _ = self.state_tx.send(BufferState::Buffering);
166
167        tracing::info!("Switchover buffer: started buffering");
168    }
169
170    /// Stop buffering (called when switchover completes or fails)
171    pub fn stop_buffering(&self) {
172        self.is_buffering.store(false, Ordering::SeqCst);
173        self.state
174            .store(BufferState::Draining as u64, Ordering::SeqCst);
175
176        let duration = self
177            .buffering_started
178            .lock()
179            .map(|start| start.elapsed())
180            .unwrap_or_default();
181
182        let _ = self.state_tx.send(BufferState::Draining);
183
184        tracing::info!(
185            "Switchover buffer: stopped buffering after {:?}, {} queries buffered",
186            duration,
187            self.buffer.lock().len()
188        );
189    }
190
191    /// Buffer a query (returns receiver for result)
192    pub fn buffer_query(
193        &self,
194        sql: String,
195        params: Vec<Vec<u8>>,
196        client_id: u64,
197    ) -> Result<oneshot::Receiver<BufferResult>> {
198        // Check if we should buffer
199        if !self.is_buffering() {
200            return Err(ProxyError::Internal("Not in buffering mode".to_string()));
201        }
202
203        // Check timeout
204        if let Some(started) = *self.buffering_started.lock() {
205            if started.elapsed() > self.config.buffer_timeout {
206                return Err(ProxyError::Timeout("Buffer timeout exceeded".to_string()));
207            }
208        }
209
210        // Check capacity
211        let buffer_len = self.buffer.lock().len();
212        if buffer_len >= self.config.max_buffered_queries {
213            self.stats.rejected_queries.fetch_add(1, Ordering::Relaxed);
214            return Err(ProxyError::PoolExhausted("Buffer full".to_string()));
215        }
216
217        // Check memory
218        let query_size = sql.len() + params.iter().map(|p| p.len()).sum::<usize>();
219        let current_memory = self.buffer_memory.load(Ordering::Relaxed) as usize;
220        if current_memory + query_size > self.config.max_buffer_memory {
221            self.stats.rejected_queries.fetch_add(1, Ordering::Relaxed);
222            return Err(ProxyError::PoolExhausted(
223                "Buffer memory exhausted".to_string(),
224            ));
225        }
226
227        // Create response channel
228        let (response_tx, response_rx) = oneshot::channel();
229
230        // Create buffered query
231        let buffered = BufferedQuery {
232            sql,
233            params,
234            buffered_at: Instant::now(),
235            response_tx,
236            client_id,
237        };
238
239        // Add to buffer
240        self.buffer.lock().push_back(buffered);
241        self.buffer_memory
242            .fetch_add(query_size as u64, Ordering::Relaxed);
243        self.stats.buffered_queries.fetch_add(1, Ordering::Relaxed);
244
245        Ok(response_rx)
246    }
247
248    /// Drain buffer and replay queries to new primary
249    ///
250    /// The `execute_fn` is called for each buffered query to execute it
251    /// against the new primary.
252    pub async fn drain<F, Fut>(&self, execute_fn: F)
253    where
254        F: Fn(String, Vec<Vec<u8>>) -> Fut,
255        Fut: std::future::Future<Output = Result<()>>,
256    {
257        tracing::info!("Switchover buffer: draining buffer");
258
259        let queries: Vec<BufferedQuery> = {
260            let mut buffer = self.buffer.lock();
261            buffer.drain(..).collect()
262        };
263
264        self.buffer_memory.store(0, Ordering::Relaxed);
265
266        let total = queries.len();
267        let mut success = 0;
268        let mut failed = 0;
269        let mut timed_out = 0;
270
271        for query in queries {
272            // Check if query timed out while buffered
273            if query.buffered_at.elapsed() > self.config.buffer_timeout {
274                let _ = query.response_tx.send(BufferResult::Timeout);
275                timed_out += 1;
276                continue;
277            }
278
279            // Execute query
280            match execute_fn(query.sql, query.params).await {
281                Ok(()) => {
282                    let _ = query.response_tx.send(BufferResult::Success);
283                    success += 1;
284                }
285                Err(e) => {
286                    let _ = query.response_tx.send(BufferResult::Error(e.to_string()));
287                    failed += 1;
288                }
289            }
290        }
291
292        self.stats
293            .replayed_queries
294            .fetch_add(success, Ordering::Relaxed);
295        self.stats
296            .failed_replays
297            .fetch_add(failed, Ordering::Relaxed);
298        self.stats
299            .timed_out_queries
300            .fetch_add(timed_out, Ordering::Relaxed);
301
302        // Return to passthrough mode
303        self.state
304            .store(BufferState::Passthrough as u64, Ordering::SeqCst);
305        let _ = self.state_tx.send(BufferState::Passthrough);
306
307        tracing::info!(
308            "Switchover buffer: drained {} queries (success: {}, failed: {}, timeout: {})",
309            total,
310            success,
311            failed,
312            timed_out
313        );
314    }
315
316    /// Fail all buffered queries (called if switchover fails)
317    pub fn fail_all(&self, error: &str) {
318        let queries: Vec<BufferedQuery> = {
319            let mut buffer = self.buffer.lock();
320            buffer.drain(..).collect()
321        };
322
323        let query_count = queries.len();
324        self.buffer_memory.store(0, Ordering::Relaxed);
325
326        for query in queries {
327            let _ = query.response_tx.send(BufferResult::SwitchoverFailed);
328        }
329
330        self.stats
331            .failed_replays
332            .fetch_add(query_count as u64, Ordering::Relaxed);
333
334        // Return to passthrough mode
335        self.state
336            .store(BufferState::Passthrough as u64, Ordering::SeqCst);
337        let _ = self.state_tx.send(BufferState::Passthrough);
338
339        tracing::warn!(
340            "Switchover buffer: failed {} queries due to: {}",
341            query_count,
342            error
343        );
344    }
345
346    /// Get current buffer length
347    pub fn len(&self) -> usize {
348        self.buffer.lock().len()
349    }
350
351    /// Check if buffer is empty
352    pub fn is_empty(&self) -> bool {
353        self.buffer.lock().is_empty()
354    }
355
356    /// Get buffer statistics
357    pub fn stats(&self) -> BufferStatsSnapshot {
358        BufferStatsSnapshot {
359            buffering_sessions: self.stats.buffering_sessions.load(Ordering::Relaxed),
360            buffered_queries: self.stats.buffered_queries.load(Ordering::Relaxed),
361            replayed_queries: self.stats.replayed_queries.load(Ordering::Relaxed),
362            failed_replays: self.stats.failed_replays.load(Ordering::Relaxed),
363            timed_out_queries: self.stats.timed_out_queries.load(Ordering::Relaxed),
364            rejected_queries: self.stats.rejected_queries.load(Ordering::Relaxed),
365            current_buffer_size: self.buffer.lock().len(),
366            current_memory_usage: self.buffer_memory.load(Ordering::Relaxed) as usize,
367        }
368    }
369}
370
/// Lifetime counters, updated lock-free through atomics.
#[derive(Default)]
struct BufferStats {
    /// Number of `start_buffering` calls.
    buffering_sessions: AtomicU64,
    /// Queries accepted into the buffer.
    buffered_queries: AtomicU64,
    /// Queries replayed successfully.
    replayed_queries: AtomicU64,
    /// Replays that returned an error, plus queries failed by `fail_all`.
    failed_replays: AtomicU64,
    /// Queries answered with a timeout at drain time instead of being replayed.
    timed_out_queries: AtomicU64,
    /// Queries refused because the buffer was full by count or by bytes.
    rejected_queries: AtomicU64,
}
381
/// Point-in-time copy of the buffer's counters and occupancy.
#[derive(Clone, Debug)]
pub struct BufferStatsSnapshot {
    /// Buffering sessions started.
    pub buffering_sessions: u64,
    /// Queries accepted into the buffer.
    pub buffered_queries: u64,
    /// Queries replayed successfully.
    pub replayed_queries: u64,
    /// Failed replays plus queries failed because the switchover was abandoned.
    pub failed_replays: u64,
    /// Queries that exceeded the buffer timeout before replay.
    pub timed_out_queries: u64,
    /// Queries refused because the buffer was full.
    pub rejected_queries: u64,
    /// Number of queries queued when the snapshot was taken.
    pub current_buffer_size: usize,
    /// Bytes (SQL text plus parameters) queued when the snapshot was taken.
    pub current_memory_usage: usize,
}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks Passthrough -> Buffering -> Draining, checking both the phase and the flag.
    #[test]
    fn test_buffer_state_transitions() {
        let sb = SwitchoverBuffer::new(BufferConfig::default());

        assert_eq!(sb.state(), BufferState::Passthrough);
        assert!(!sb.is_buffering());

        sb.start_buffering();
        assert_eq!(sb.state(), BufferState::Buffering);
        assert!(sb.is_buffering());

        sb.stop_buffering();
        assert_eq!(sb.state(), BufferState::Draining);
        assert!(!sb.is_buffering());
    }

    /// A query queued during a switchover is replayed and its client sees `Success`.
    #[tokio::test]
    async fn test_buffer_query() {
        let sb = SwitchoverBuffer::new(BufferConfig::default());

        // Outside a switchover window the buffer refuses work.
        let refused = sb.buffer_query("SELECT 1".to_string(), vec![], 1);
        assert!(refused.is_err());

        sb.start_buffering();

        let pending = sb
            .buffer_query("INSERT INTO t VALUES (1)".to_string(), vec![], 1)
            .unwrap();
        assert_eq!(sb.len(), 1);

        sb.drain(|_, _| async { Ok(()) }).await;

        let outcome = pending.await.unwrap();
        assert!(matches!(outcome, BufferResult::Success));
        assert!(sb.is_empty());
    }

    /// The queue refuses the first query beyond `max_buffered_queries`.
    #[test]
    fn test_buffer_limits() {
        let sb = SwitchoverBuffer::new(BufferConfig {
            max_buffered_queries: 2,
            ..Default::default()
        });
        sb.start_buffering();

        // Fill the buffer exactly to its limit.
        for (id, sql) in [(1, "Q1"), (2, "Q2")] {
            let _ = sb.buffer_query(sql.to_string(), vec![], id).unwrap();
        }

        let overflow = sb.buffer_query("Q3".to_string(), vec![], 3);
        assert!(overflow.is_err());
    }
}