aurora_db/storage/
write_buffer.rs

1use crate::error::{AuroraError, Result};
2use crate::storage::ColdStore;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::mpsc;
6use std::time::{Duration, Instant};
7
8pub enum WriteOp {
9    Write { key: String, value: Vec<u8> },
10    Flush(mpsc::SyncSender<Result<()>>),
11    Shutdown,
12}
13
14pub struct WriteBuffer {
15    sender: mpsc::SyncSender<WriteOp>,
16    is_alive: Arc<AtomicBool>,
17}
18
19impl WriteBuffer {
20    pub fn new(cold: Arc<ColdStore>, buffer_size: usize, flush_interval_ms: u64) -> Self {
21        let (sender, receiver) = mpsc::sync_channel::<WriteOp>(1000);
22        let is_alive = Arc::new(AtomicBool::new(true));
23        let task_is_alive = Arc::clone(&is_alive);
24
25        // Use a real OS thread instead of tokio::spawn
26        // This allows Drop to safely block without async context issues
27        std::thread::spawn(move || {
28            struct TaskGuard(Arc<AtomicBool>);
29            impl Drop for TaskGuard {
30                fn drop(&mut self) {
31                    self.0.store(false, Ordering::SeqCst);
32                }
33            }
34            let _guard = TaskGuard(task_is_alive);
35
36            let mut batch = Vec::with_capacity(buffer_size);
37            let mut last_flush = Instant::now();
38            let flush_duration = Duration::from_millis(flush_interval_ms);
39
40            loop {
41                // Try to receive with a timeout to allow periodic flushes
42                let timeout = flush_duration.saturating_sub(last_flush.elapsed());
43                let op = receiver.recv_timeout(timeout);
44
45                match op {
46                    Ok(WriteOp::Write { key, value }) => {
47                        batch.push((key, value));
48
49                        // Flush if batch is full
50                        if batch.len() >= buffer_size {
51                            if let Err(e) = cold.batch_set(std::mem::take(&mut batch)) {
52                                eprintln!("Write buffer flush error: {}", e);
53                            }
54                            last_flush = Instant::now();
55                        }
56                    }
57                    Ok(WriteOp::Flush(response)) => {
58                        let result = if !batch.is_empty() {
59                            cold.batch_set(std::mem::take(&mut batch))
60                        } else {
61                            Ok(())
62                        };
63                        let _ = response.send(result);
64                        last_flush = Instant::now();
65                    }
66                    Ok(WriteOp::Shutdown) => {
67                        // Final flush before shutdown
68                        if !batch.is_empty()
69                            && let Err(e) = cold.batch_set(batch) {
70                                eprintln!("Write buffer shutdown flush error: {}", e);
71                            }
72                        break; // Exit gracefully
73                    }
74                    Err(mpsc::RecvTimeoutError::Timeout) => {
75                        // Periodic flush
76                        if !batch.is_empty() && last_flush.elapsed() >= flush_duration {
77                            if let Err(e) = cold.batch_set(std::mem::take(&mut batch)) {
78                                eprintln!("Write buffer periodic flush error: {}", e);
79                            }
80                            last_flush = Instant::now();
81                        }
82                    }
83                    Err(mpsc::RecvTimeoutError::Disconnected) => {
84                        // Channel closed, flush and exit
85                        if !batch.is_empty()
86                            && let Err(e) = cold.batch_set(batch) {
87                                eprintln!("Write buffer final flush error: {}", e);
88                            }
89                        break;
90                    }
91                }
92            }
93        });
94
95        Self { sender, is_alive }
96    }
97
98    pub fn write(&self, key: String, value: Vec<u8>) -> Result<()> {
99        if !self.is_alive.load(Ordering::SeqCst) {
100            return Err(AuroraError::InvalidOperation(
101                "Write buffer is not active.".into(),
102            ));
103        }
104        self.sender.send(WriteOp::Write { key, value }).map_err(|_| {
105            AuroraError::InvalidOperation("Write buffer channel closed unexpectedly.".into())
106        })?;
107        Ok(())
108    }
109
110    pub fn flush(&self) -> Result<()> {
111        let (tx, rx) = mpsc::sync_channel(1);
112        self.sender
113            .send(WriteOp::Flush(tx))
114            .map_err(|_| AuroraError::InvalidOperation("Write buffer closed".into()))?;
115
116        rx.recv()
117            .map_err(|_| AuroraError::InvalidOperation("Flush response lost".into()))?
118    }
119
120    pub fn is_active(&self) -> bool {
121        self.is_alive.load(Ordering::SeqCst)
122    }
123}
124
125impl Drop for WriteBuffer {
126    fn drop(&mut self) {
127        // Signal graceful shutdown
128        // The background thread (not tokio task) will perform final flush
129        // We can safely send because we're using std::sync::mpsc
130        let _ = self.sender.send(WriteOp::Shutdown);
131
132        // Give the background thread a moment to flush
133        // The thread will exit cleanly after flushing
134        std::thread::sleep(std::time::Duration::from_millis(50));
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Polls `cold` until `key` reads back as `expected`, giving up after `timeout`.
    ///
    /// Writes are applied by the buffer's background thread. Tests that do not call
    /// `flush()` therefore wait for the effect rather than sleeping for a guessed time.
    async fn wait_for_value(
        cold: &ColdStore,
        key: &str,
        expected: &[u8],
        timeout: Duration,
    ) -> Result<bool> {
        let deadline = Instant::now() + timeout;
        loop {
            if cold.get(key)? == Some(expected.to_vec()) {
                return Ok(true);
            }
            if Instant::now() >= deadline {
                return Ok(false);
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    }

    #[tokio::test]
    async fn test_write_buffer() -> Result<()> {
        let temp_dir = tempdir()?;
        let db_path = temp_dir.path().join("test.db");
        let cold = Arc::new(ColdStore::new(db_path.to_str().unwrap())?);

        let buffer = WriteBuffer::new(Arc::clone(&cold), 100, 50);

        buffer.write("key1".to_string(), b"value1".to_vec())?;
        buffer.write("key2".to_string(), b"value2".to_vec())?;
        buffer.write("key3".to_string(), b"value3".to_vec())?;

        // Explicitly flush to ensure data is written
        buffer.flush()?;

        assert_eq!(cold.get("key1")?, Some(b"value1".to_vec()));
        assert_eq!(cold.get("key2")?, Some(b"value2".to_vec()));
        assert_eq!(cold.get("key3")?, Some(b"value3".to_vec()));

        Ok(())
    }

    #[tokio::test]
    async fn test_write_buffer_batch_flush() -> Result<()> {
        let temp_dir = tempdir()?;
        let db_path = temp_dir.path().join("test.db");
        let cold = Arc::new(ColdStore::new(db_path.to_str().unwrap())?);

        // Ten writes fill exactly two batches of five. The interval is far longer
        // than the wait below, so only the batch-size trigger can persist them.
        let buffer = WriteBuffer::new(Arc::clone(&cold), 5, 60_000);

        for i in 0..10 {
            buffer.write(format!("key{}", i), format!("value{}", i).into_bytes())?;
        }

        for i in 0..10 {
            let key = format!("key{}", i);
            let value = format!("value{}", i).into_bytes();
            assert!(
                wait_for_value(&cold, &key, &value, Duration::from_secs(5)).await?,
                "{} was not persisted by the batch-size flush",
                key
            );
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_write_buffer_periodic_flush() -> Result<()> {
        let temp_dir = tempdir()?;
        let db_path = temp_dir.path().join("test.db");
        let cold = Arc::new(ColdStore::new(db_path.to_str().unwrap())?);

        // The batch size is never reached, so only the 50ms timer can flush.
        let buffer = WriteBuffer::new(Arc::clone(&cold), 100, 50);

        buffer.write("key1".to_string(), b"value1".to_vec())?;
        buffer.write("key2".to_string(), b"value2".to_vec())?;

        for (key, value) in [("key1", b"value1"), ("key2", b"value2")] {
            assert!(
                wait_for_value(&cold, key, value, Duration::from_secs(5)).await?,
                "{} was not persisted by the periodic flush",
                key
            );
        }
        assert!(buffer.is_active());

        Ok(())
    }
}