Skip to main content

aurora_db/storage/
write_buffer.rs

1use crate::error::{AqlError, Result};
2use crate::storage::ColdStore;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::mpsc;
6use std::time::{Duration, Instant};
7
/// A message sent from a `WriteBuffer` handle to its background writer thread.
pub enum WriteOp {
    /// Append `(key, value)` to the writer thread's pending batch.
    Write { key: Arc<String>, value: Arc<Vec<u8>> },
    /// Write out the pending batch now and send the outcome back on the enclosed sender.
    Flush(mpsc::SyncSender<Result<()>>),
    /// Write out the pending batch (errors are only logged), then stop the writer thread.
    Shutdown,
}
13
/// Batches key/value writes and hands them to a `ColdStore` from a dedicated
/// background OS thread.
pub struct WriteBuffer {
    /// Bounded channel to the writer thread (capacity set in `WriteBuffer::new`).
    sender: mpsc::SyncSender<WriteOp>,
    /// Cleared by a guard owned by the writer thread when that thread exits.
    is_alive: Arc<AtomicBool>,
    /// Writer thread handle; taken and joined when the buffer is dropped.
    thread_handle: Option<std::thread::JoinHandle<()>>,
}
19
20impl WriteBuffer {
21    pub fn new(cold: Arc<ColdStore>, buffer_size: usize, flush_interval_ms: u64) -> Self {
22        let (sender, receiver) = mpsc::sync_channel::<WriteOp>(1000);
23        let is_alive = Arc::new(AtomicBool::new(true));
24        let task_is_alive = Arc::clone(&is_alive);
25
26        // Use a real OS thread instead of tokio::spawn
27        // This allows Drop to safely block without async context issues
28        let handle = std::thread::spawn(move || {
29            struct TaskGuard(Arc<AtomicBool>);
30            impl Drop for TaskGuard {
31                fn drop(&mut self) {
32                    self.0.store(false, Ordering::SeqCst);
33                }
34            }
35            let _guard = TaskGuard(task_is_alive);
36
37            let mut batch = Vec::with_capacity(buffer_size);
38            let mut last_flush = Instant::now();
39            let flush_duration = Duration::from_millis(flush_interval_ms);
40
41            loop {
42                // Try to receive with a timeout to allow periodic flushes
43                let timeout = flush_duration.saturating_sub(last_flush.elapsed());
44                let op = receiver.recv_timeout(timeout);
45
46                match op {
47                    Ok(WriteOp::Write { key, value }) => {
48                        batch.push((key, value));
49
50                        // Flush if batch is full
51                        if batch.len() >= buffer_size {
52                            let batch_to_write = std::mem::take(&mut batch);
53                            if let Err(e) = cold.batch_set_arc(batch_to_write) {
54                                eprintln!("Write buffer flush error: {}", e);
55                            }
56                            last_flush = Instant::now();
57                        }
58                    }
59                    Ok(WriteOp::Flush(response)) => {
60                        let result = if !batch.is_empty() {
61                            let batch_to_write = std::mem::take(&mut batch);
62                            cold.batch_set_arc(batch_to_write)
63                        } else {
64                            Ok(())
65                        };
66                        let _ = response.send(result);
67                        last_flush = Instant::now();
68                    }
69                    Ok(WriteOp::Shutdown) => {
70                        // Final flush before shutdown
71                        if !batch.is_empty() {
72                            let batch_to_write = std::mem::take(&mut batch);
73                            if let Err(e) = cold.batch_set_arc(batch_to_write) {
74                                eprintln!("Write buffer shutdown flush error: {}", e);
75                            }
76                        }
77                        break;
78                    }
79                    Err(mpsc::RecvTimeoutError::Timeout) => {
80                        // Periodic flush
81                        if !batch.is_empty() && last_flush.elapsed() >= flush_duration {
82                            let batch_to_write = std::mem::take(&mut batch);
83                            
84                            // Use a match to handle the error cleanly
85                            match cold.batch_set_arc(batch_to_write) {
86                                Ok(_) => last_flush = Instant::now(),
87                                Err(_) => {
88                                    eprintln!("Write buffer periodic flush error: Disk Full. Pausing writes.");
89                                    // PAUSE OR RETURN ERROR
90                                    std::thread::sleep(Duration::from_millis(100)); 
91                                }
92                            }
93                        }
94                    }
95                    Err(mpsc::RecvTimeoutError::Disconnected) => {
96                        // Channel closed, flush and exit
97                        if !batch.is_empty() {
98                            let batch_to_write = std::mem::take(&mut batch);
99                            if let Err(e) = cold.batch_set_arc(batch_to_write) {
100                                eprintln!("Write buffer final flush error: {}", e);
101                            }
102                        }
103                        break;
104                    }
105                }
106            }
107        });
108
109        Self {
110            sender,
111            is_alive,
112            thread_handle: Some(handle),
113        }
114    }
115
116    pub fn write(&self, key: Arc<String>, value: Arc<Vec<u8>>) -> Result<()> {
117        if !self.is_alive.load(Ordering::SeqCst) {
118            return Err(AqlError::invalid_operation(
119                "Write buffer is not active.".to_string(),
120            ));
121        }
122        self.sender
123            .send(WriteOp::Write { key, value })
124            .map_err(|_| {
125                AqlError::invalid_operation("Write buffer channel closed unexpectedly.".to_string())
126            })?;
127        Ok(())
128    }
129
130    pub fn flush(&self) -> Result<()> {
131        let (tx, rx) = mpsc::sync_channel(1);
132        self.sender
133            .send(WriteOp::Flush(tx))
134            .map_err(|_| AqlError::invalid_operation("Write buffer closed".to_string()))?;
135
136        rx.recv()
137            .map_err(|_| AqlError::invalid_operation("Flush response lost".to_string()))?
138    }
139
140    pub fn is_active(&self) -> bool {
141        self.is_alive.load(Ordering::SeqCst)
142    }
143}
144
145impl Drop for WriteBuffer {
146    fn drop(&mut self) {
147        let _ = self.sender.send(WriteOp::Shutdown);
148
149        // HACK: Join thread to prevent Windows zombie process (causing LNK1104 on next build)
150        if let Some(handle) = self.thread_handle.take() {
151            let _ = handle.join();
152        }
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Polls `cold` until `key` has a value or `timeout` elapses, returning the last
    /// value read. This lets tests observe background flushes without a fixed sleep.
    fn wait_for_value(
        cold: &ColdStore,
        key: &str,
        timeout: Duration,
    ) -> Result<Option<Vec<u8>>> {
        let deadline = Instant::now() + timeout;
        loop {
            let value = cold.get(key)?;
            if value.is_some() || Instant::now() >= deadline {
                return Ok(value);
            }
            std::thread::sleep(Duration::from_millis(10));
        }
    }

    #[test]
    fn test_write_buffer() -> Result<()> {
        let temp_dir = tempdir()?;
        let db_path = temp_dir.path().join("test.db");
        let cold = Arc::new(ColdStore::new(db_path.to_str().unwrap())?);

        let buffer = WriteBuffer::new(Arc::clone(&cold), 100, 50);

        buffer.write(Arc::new("key1".to_string()), Arc::new(b"value1".to_vec()))?;
        buffer.write(Arc::new("key2".to_string()), Arc::new(b"value2".to_vec()))?;
        buffer.write(Arc::new("key3".to_string()), Arc::new(b"value3".to_vec()))?;

        // Explicitly flush to ensure data is written
        buffer.flush()?;

        assert_eq!(cold.get("key1")?, Some(b"value1".to_vec()));
        assert_eq!(cold.get("key2")?, Some(b"value2".to_vec()));
        assert_eq!(cold.get("key3")?, Some(b"value3".to_vec()));

        Ok(())
    }

    #[test]
    fn test_write_buffer_batch_flush() -> Result<()> {
        let temp_dir = tempdir()?;
        let db_path = temp_dir.path().join("test.db");
        let cold = Arc::new(ColdStore::new(db_path.to_str().unwrap())?);

        // The flush interval is far longer than this test runs, so only the
        // batch-size trigger can write these entries out.
        let buffer = WriteBuffer::new(Arc::clone(&cold), 5, 60_000);

        for i in 0..10 {
            buffer.write(Arc::new(format!("key{}", i)), Arc::new(format!("value{}", i).into_bytes()))?;
        }

        for i in 0..10 {
            assert_eq!(
                wait_for_value(&cold, &format!("key{}", i), Duration::from_secs(5))?,
                Some(format!("value{}", i).into_bytes())
            );
        }

        Ok(())
    }

    #[test]
    fn test_write_buffer_periodic_flush() -> Result<()> {
        let temp_dir = tempdir()?;
        let db_path = temp_dir.path().join("test.db");
        let cold = Arc::new(ColdStore::new(db_path.to_str().unwrap())?);

        // The batch never fills up, so only the flush interval can write it out.
        let buffer = WriteBuffer::new(Arc::clone(&cold), 100, 50);
        buffer.write(Arc::new("key1".to_string()), Arc::new(b"value1".to_vec()))?;

        assert_eq!(
            wait_for_value(&cold, "key1", Duration::from_secs(5))?,
            Some(b"value1".to_vec())
        );
        assert!(buffer.is_active());

        Ok(())
    }

    #[test]
    fn test_flush_empty_buffer() -> Result<()> {
        let temp_dir = tempdir()?;
        let db_path = temp_dir.path().join("test.db");
        let cold = Arc::new(ColdStore::new(db_path.to_str().unwrap())?);

        let buffer = WriteBuffer::new(Arc::clone(&cold), 10, 50);
        buffer.flush()?;
        assert!(buffer.is_active());

        Ok(())
    }
}