// aurora_db/storage/write_buffer.rs

1use crate::error::{AqlError, Result};
2use crate::storage::ColdStore;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::mpsc;
6use std::time::{Duration, Instant};
7
8pub enum WriteOp {
9    Write {
10        key: Arc<String>,
11        value: Arc<Vec<u8>>,
12    },
13    Flush(mpsc::SyncSender<Result<()>>),
14    Shutdown,
15}
16
17pub struct WriteBuffer {
18    sender: mpsc::SyncSender<WriteOp>,
19    is_alive: Arc<AtomicBool>,
20    thread_handle: Option<std::thread::JoinHandle<()>>,
21}
22
23impl WriteBuffer {
24    pub fn new(cold: Arc<ColdStore>, buffer_size: usize, flush_interval_ms: u64) -> Self {
25        let (sender, receiver) = mpsc::sync_channel::<WriteOp>(1000);
26        let is_alive = Arc::new(AtomicBool::new(true));
27        let task_is_alive = Arc::clone(&is_alive);
28
29        // Use a real OS thread instead of tokio::spawn
30        // This allows Drop to safely block without async context issues
31        let handle = std::thread::spawn(move || {
32            struct TaskGuard(Arc<AtomicBool>);
33            impl Drop for TaskGuard {
34                fn drop(&mut self) {
35                    self.0.store(false, Ordering::SeqCst);
36                }
37            }
38            let _guard = TaskGuard(task_is_alive);
39
40            let mut batch = Vec::with_capacity(buffer_size);
41            let mut last_flush = Instant::now();
42            let flush_duration = Duration::from_millis(flush_interval_ms);
43
44            loop {
45                // Try to receive with a timeout to allow periodic flushes
46                let timeout = flush_duration.saturating_sub(last_flush.elapsed());
47                let op = receiver.recv_timeout(timeout);
48
49                match op {
50                    Ok(WriteOp::Write { key, value }) => {
51                        batch.push((key, value));
52
53                        // Flush if batch is full
54                        if batch.len() >= buffer_size {
55                            let batch_to_write = std::mem::take(&mut batch);
56                            if let Err(e) = cold.batch_set_arc(batch_to_write) {
57                                eprintln!("Write buffer flush error: {}", e);
58                            }
59                            last_flush = Instant::now();
60                        }
61                    }
62                    Ok(WriteOp::Flush(response)) => {
63                        let result = if !batch.is_empty() {
64                            let batch_to_write = std::mem::take(&mut batch);
65                            cold.batch_set_arc(batch_to_write)
66                        } else {
67                            Ok(())
68                        };
69                        let _ = response.send(result);
70                        last_flush = Instant::now();
71                    }
72                    Ok(WriteOp::Shutdown) => {
73                        // Final flush before shutdown
74                        if !batch.is_empty() {
75                            let batch_to_write = std::mem::take(&mut batch);
76                            if let Err(e) = cold.batch_set_arc(batch_to_write) {
77                                eprintln!("Write buffer shutdown flush error: {}", e);
78                            }
79                        }
80                        break;
81                    }
82                    Err(mpsc::RecvTimeoutError::Timeout) => {
83                        // Periodic flush
84                        if !batch.is_empty() && last_flush.elapsed() >= flush_duration {
85                            let batch_to_write = std::mem::take(&mut batch);
86
87                            // Use a match to handle the error cleanly
88                            match cold.batch_set_arc(batch_to_write) {
89                                Ok(_) => last_flush = Instant::now(),
90                                Err(_) => {
91                                    eprintln!(
92                                        "Write buffer periodic flush error: Disk Full. Pausing writes."
93                                    );
94                                    // PAUSE OR RETURN ERROR
95                                    std::thread::sleep(Duration::from_millis(100));
96                                }
97                            }
98                        }
99                    }
100                    Err(mpsc::RecvTimeoutError::Disconnected) => {
101                        // Channel closed, flush and exit
102                        if !batch.is_empty() {
103                            let batch_to_write = std::mem::take(&mut batch);
104                            if let Err(e) = cold.batch_set_arc(batch_to_write) {
105                                eprintln!("Write buffer final flush error: {}", e);
106                            }
107                        }
108                        break;
109                    }
110                }
111            }
112        });
113
114        Self {
115            sender,
116            is_alive,
117            thread_handle: Some(handle),
118        }
119    }
120
121    pub fn write(&self, key: Arc<String>, value: Arc<Vec<u8>>) -> Result<()> {
122        if !self.is_alive.load(Ordering::SeqCst) {
123            return Err(AqlError::invalid_operation(
124                "Write buffer is not active.".to_string(),
125            ));
126        }
127        self.sender
128            .send(WriteOp::Write { key, value })
129            .map_err(|_| {
130                AqlError::invalid_operation("Write buffer channel closed unexpectedly.".to_string())
131            })?;
132        Ok(())
133    }
134
135    pub fn flush(&self) -> Result<()> {
136        let (tx, rx) = mpsc::sync_channel(1);
137        self.sender
138            .send(WriteOp::Flush(tx))
139            .map_err(|_| AqlError::invalid_operation("Write buffer closed".to_string()))?;
140
141        rx.recv()
142            .map_err(|_| AqlError::invalid_operation("Flush response lost".to_string()))?
143    }
144
145    pub fn is_active(&self) -> bool {
146        self.is_alive.load(Ordering::SeqCst)
147    }
148}
149
150impl Drop for WriteBuffer {
151    fn drop(&mut self) {
152        let _ = self.sender.send(WriteOp::Shutdown);
153
154        // HACK: Join thread to prevent Windows zombie process (causing LNK1104 on next build)
155        if let Some(handle) = self.thread_handle.take() {
156            let _ = handle.join();
157        }
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Polls `ready` every 10 ms until it returns `true` or `timeout` elapses.
    ///
    /// Returns whether the condition was met. Used instead of fixed sleeps so the
    /// tests finish as soon as the background flusher has caught up.
    async fn wait_until<F>(timeout: Duration, mut ready: F) -> Result<bool>
    where
        F: FnMut() -> Result<bool>,
    {
        let deadline = Instant::now() + timeout;
        loop {
            if ready()? {
                return Ok(true);
            }
            if Instant::now() >= deadline {
                return Ok(false);
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    }

    #[tokio::test]
    async fn test_write_buffer() -> Result<()> {
        let temp_dir = tempdir()?;
        let db_path = temp_dir.path().join("test.db");
        let cold = Arc::new(ColdStore::new(db_path.to_str().unwrap())?);

        let buffer = WriteBuffer::new(Arc::clone(&cold), 100, 50);

        buffer.write(Arc::new("key1".to_string()), Arc::new(b"value1".to_vec()))?;
        buffer.write(Arc::new("key2".to_string()), Arc::new(b"value2".to_vec()))?;
        buffer.write(Arc::new("key3".to_string()), Arc::new(b"value3".to_vec()))?;

        // Explicitly flush to ensure data is written
        buffer.flush()?;

        assert_eq!(cold.get("key1")?, Some(b"value1".to_vec()));
        assert_eq!(cold.get("key2")?, Some(b"value2".to_vec()));
        assert_eq!(cold.get("key3")?, Some(b"value3".to_vec()));

        Ok(())
    }

    #[tokio::test]
    async fn test_write_buffer_batch_flush() -> Result<()> {
        let temp_dir = tempdir()?;
        let db_path = temp_dir.path().join("test.db");
        let cold = Arc::new(ColdStore::new(db_path.to_str().unwrap())?);

        // Batches of 5 with a 60 s interval: the timer cannot fire during the test,
        // so the ten writes can only reach the store via the size trigger.
        let buffer = WriteBuffer::new(Arc::clone(&cold), 5, 60_000);

        for i in 0..10 {
            buffer.write(
                Arc::new(format!("key{}", i)),
                Arc::new(format!("value{}", i).into_bytes()),
            )?;
        }

        let all_flushed = wait_until(Duration::from_secs(5), || {
            for i in 0..10 {
                if cold.get(&format!("key{}", i))?.is_none() {
                    return Ok(false);
                }
            }
            Ok(true)
        })
        .await?;
        assert!(all_flushed, "size-triggered flushes did not persist all keys");

        for i in 0..10 {
            assert_eq!(
                cold.get(&format!("key{}", i))?,
                Some(format!("value{}", i).into_bytes())
            );
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_write_buffer_periodic_flush() -> Result<()> {
        let temp_dir = tempdir()?;
        let db_path = temp_dir.path().join("test.db");
        let cold = Arc::new(ColdStore::new(db_path.to_str().unwrap())?);

        // A batch that never fills, so only the 50 ms timer can flush it.
        let buffer = WriteBuffer::new(Arc::clone(&cold), 100, 50);

        buffer.write(Arc::new("timed".to_string()), Arc::new(b"value".to_vec()))?;

        let flushed =
            wait_until(Duration::from_secs(5), || Ok(cold.get("timed")?.is_some())).await?;
        assert!(flushed, "periodic flush did not persist the pending write");
        assert_eq!(cold.get("timed")?, Some(b"value".to_vec()));

        Ok(())
    }
}