aurora_db/storage/
write_buffer.rs

1use crate::error::{AuroraError, Result};
2use crate::storage::ColdStore;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicBool, Ordering};
5use tokio::sync::mpsc;
6use tokio::time::{Duration, interval};
7
/// A single pending write: the key to set and the raw bytes to store under it.
///
/// Instances are sent through the buffer's channel to the background flush task.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WriteOp {
    /// Key the value will be stored under.
    pub key: String,
    /// Raw value bytes.
    pub value: Vec<u8>,
}
13
14pub struct WriteBuffer {
15    sender: mpsc::UnboundedSender<WriteOp>,
16    is_alive: Arc<AtomicBool>,
17}
18
19impl WriteBuffer {
20    pub fn new(cold: Arc<ColdStore>, buffer_size: usize, flush_interval_ms: u64) -> Self {
21        let (sender, mut receiver) = mpsc::unbounded_channel::<WriteOp>();
22        let is_alive = Arc::new(AtomicBool::new(true));
23        let task_is_alive = Arc::clone(&is_alive);
24
25        tokio::spawn(async move {
26            struct TaskGuard(Arc<AtomicBool>);
27            impl Drop for TaskGuard {
28                fn drop(&mut self) {
29                    self.0.store(false, Ordering::SeqCst);
30                }
31            }
32            let _guard = TaskGuard(task_is_alive);
33
34            let mut batch = Vec::with_capacity(buffer_size);
35            let mut flush_timer = interval(Duration::from_millis(flush_interval_ms));
36
37            loop {
38                tokio::select! {
39                    Some(op) = receiver.recv() => {
40                        batch.push((op.key, op.value));
41
42                        if batch.len() >= buffer_size {
43                            if let Err(e) = cold.batch_set(batch.drain(..).collect()) {
44                                eprintln!("Write buffer flush error: {}", e);
45                            }
46                        }
47                    }
48
49                    _ = flush_timer.tick() => {
50                        if !batch.is_empty() {
51                            if let Err(e) = cold.batch_set(batch.drain(..).collect()) {
52                                eprintln!("Write buffer periodic flush error: {}", e);
53                            }
54                        }
55                    }
56
57                    else => break,
58                }
59            }
60
61            if !batch.is_empty() {
62                if let Err(e) = cold.batch_set(batch) {
63                    eprintln!("Write buffer final flush error: {}", e);
64                }
65            }
66        });
67
68        Self { sender, is_alive }
69    }
70
71    pub fn write(&self, key: String, value: Vec<u8>) -> Result<()> {
72        if !self.is_alive.load(Ordering::SeqCst) {
73            return Err(AuroraError::InvalidOperation(
74                "Write buffer is not active.".into(),
75            ));
76        }
77        self.sender.send(WriteOp { key, value }).map_err(|_| {
78            AuroraError::InvalidOperation("Write buffer channel closed unexpectedly.".into())
79        })?;
80        Ok(())
81    }
82
83    pub fn is_active(&self) -> bool {
84        self.is_alive.load(Ordering::SeqCst)
85    }
86}
87
88impl Drop for WriteBuffer {
89    fn drop(&mut self) {}
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Writes that stay below the size threshold must still reach the cold
    /// store through the periodic flush.
    #[tokio::test]
    async fn test_write_buffer() -> Result<()> {
        let dir = tempdir()?;
        let path = dir.path().join("test.db");
        let store = Arc::new(ColdStore::new(path.to_str().unwrap())?);

        // A threshold of 100 is never reached here, so only the 50 ms timer flushes.
        let buffer = WriteBuffer::new(Arc::clone(&store), 100, 50);

        let entries = [("key1", "value1"), ("key2", "value2"), ("key3", "value3")];
        for &(key, value) in entries.iter() {
            buffer.write(key.to_string(), value.as_bytes().to_vec())?;
        }

        // Wait long enough for a timer tick to flush the batch.
        tokio::time::sleep(Duration::from_millis(100)).await;

        for &(key, value) in entries.iter() {
            assert_eq!(store.get(key)?, Some(value.as_bytes().to_vec()));
        }

        Ok(())
    }

    /// Reaching the size threshold must flush without waiting for the timer.
    #[tokio::test]
    async fn test_write_buffer_batch_flush() -> Result<()> {
        let dir = tempdir()?;
        let path = dir.path().join("test.db");
        let store = Arc::new(ColdStore::new(path.to_str().unwrap())?);

        // Ten writes against a threshold of 5, with a 1000 ms timer period.
        let buffer = WriteBuffer::new(Arc::clone(&store), 5, 1000);

        let expected: Vec<(String, Vec<u8>)> = (0..10)
            .map(|i| (format!("key{}", i), format!("value{}", i).into_bytes()))
            .collect();

        for (key, value) in &expected {
            buffer.write(key.clone(), value.clone())?;
        }

        tokio::time::sleep(Duration::from_millis(50)).await;

        for (key, value) in &expected {
            assert_eq!(store.get(key)?, Some(value.clone()));
        }

        Ok(())
    }
}