1use std::{
2 collections::HashMap,
3 sync::{atomic::Ordering, Arc},
4};
5
6use bytes::{BufMut, Bytes, BytesMut};
7use parking_lot::Mutex;
8use prost::{decode_length_delimiter, encode_length_delimiter};
9
10use crate::{
11 data::log_record::{LogRecord, LogRecordType},
12 db::Engine,
13 errors::{Errors, Result},
14 option::{IndexType, WriteBatchOptions},
15};
16
/// Key (before the sequence-number prefix is added) of the marker record that
/// `WriteBatch::commit` appends after all of a batch's records.
const TXN_FIN_KEY: &[u8] = b"txn-fin";

/// Presumably the sequence number given to records written outside of a write batch.
/// NOTE(review): batch numbers come from `Engine::seq_no`; confirm it never hands out 0.
pub(crate) const NON_TXN_SEQ_NO: usize = 0;
19
20pub struct WriteBatch<'a> {
22 pending_writes: Arc<Mutex<HashMap<Vec<u8>, LogRecord>>>, engine: &'a Engine,
24 options: WriteBatchOptions,
25}
26
27impl Engine {
28 pub fn new_write_batch(&self, options: WriteBatchOptions) -> Result<WriteBatch> {
30 if self.options.index_type == IndexType::BPlusTree && !self.seq_file_exists && !self.is_initial
31 {
32 return Err(Errors::UnableToUseWriteBatch);
33 }
34
35 Ok(WriteBatch {
36 pending_writes: Arc::new(Mutex::new(HashMap::new())),
37 engine: self,
38 options,
39 })
40 }
41}
42
43impl WriteBatch<'_> {
44 pub fn put(&self, key: Bytes, value: Bytes) -> Result<()> {
46 if key.is_empty() {
47 return Err(Errors::KeyIsEmpty);
48 }
49
50 let record = LogRecord {
52 key: key.to_vec(),
53 value: value.to_vec(),
54 rec_type: LogRecordType::Normal,
55 };
56
57 let mut pending_writes = self.pending_writes.lock();
58 pending_writes.insert(key.to_vec(), record);
59 Ok(())
60 }
61
62 pub fn delete(&self, key: Bytes) -> Result<()> {
63 if key.is_empty() {
64 return Err(Errors::KeyIsEmpty);
65 }
66
67 let mut pending_writes = self.pending_writes.lock();
68 let index_pos = self.engine.index.get(key.to_vec());
70 if index_pos.is_none() {
71 if pending_writes.contains_key(&key.to_vec()) {
72 pending_writes.remove(&key.to_vec());
73 }
74 return Ok(());
75 }
76
77 let record = LogRecord {
79 key: key.to_vec(),
80 value: Default::default(),
81 rec_type: LogRecordType::Deleted,
82 };
83 pending_writes.insert(key.to_vec(), record);
84 Ok(())
85 }
86
87 pub fn commit(&self) -> Result<()> {
89 let mut pending_writes = self.pending_writes.lock();
90 if pending_writes.len() == 0 {
91 return Ok(());
92 }
93 if pending_writes.len() > self.options.max_batch_num {
94 return Err(Errors::ExceedMaxBatchNum);
95 }
96
97 let _lock = self.engine.batch_commit_lock.lock();
99
100 let seq_no = self.engine.seq_no.fetch_add(1, Ordering::SeqCst);
102
103 let mut positions = HashMap::new();
104 for (_, item) in pending_writes.iter() {
106 let mut record = LogRecord {
107 key: log_record_key_with_seq(item.key.clone(), seq_no),
108 value: item.value.clone(),
109 rec_type: item.rec_type,
110 };
111
112 let pos = self.engine.append_log_record(&mut record)?;
113 positions.insert(item.key.clone(), pos);
114 }
115
116 let mut finish_record = LogRecord {
118 key: log_record_key_with_seq(TXN_FIN_KEY.to_vec(), seq_no),
119 value: Default::default(),
120 rec_type: LogRecordType::TxnFinished,
121 };
122
123 self.engine.append_log_record(&mut finish_record)?;
125 if self.options.sync_writes {
126 self.engine.sync()?;
127 }
128
129 for (_, item) in pending_writes.iter() {
131 let record_pos = positions.get(&item.key).unwrap();
132 if item.rec_type == LogRecordType::Normal {
133 if let Some(old_pos) = self.engine.index.put(item.key.clone(), *record_pos) {
134 self
135 .engine
136 .reclaim_size
137 .fetch_add(old_pos.size as usize, Ordering::SeqCst);
138 }
139 }
140 if item.rec_type == LogRecordType::Deleted {
141 if let Some(old_pos) = self.engine.index.delete(item.key.clone()) {
142 self
143 .engine
144 .reclaim_size
145 .fetch_add(old_pos.size as usize, Ordering::SeqCst);
146 }
147 }
148 }
149
150 pending_writes.clear();
152
153 Ok(())
154 }
155}
156
157pub(crate) fn log_record_key_with_seq(key: Vec<u8>, seq_no: usize) -> Vec<u8> {
159 let mut enc_key = BytesMut::new();
160 encode_length_delimiter(seq_no, &mut enc_key).unwrap();
161 enc_key.extend_from_slice(&key.to_vec());
162 enc_key.to_vec()
163}
164
165pub(crate) fn parse_log_record_key(key: Vec<u8>) -> (Vec<u8>, usize) {
167 let mut buf = BytesMut::new();
168 buf.put_slice(&key);
169 let seq_no = decode_length_delimiter(&mut buf).unwrap();
170 (buf.to_vec(), seq_no)
171}
172
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::{
        option::Options,
        util::rand_kv::{get_test_key, get_test_value},
    };

    use super::*;

    /// Builds engine options rooted at `dir`, first removing anything an earlier
    /// (possibly failed) run left there so each test starts from an empty directory.
    /// Every test must use its own `dir`: `cargo test` runs tests in parallel.
    fn fresh_options(dir: &str) -> Options {
        let mut opt = Options::default();
        opt.dir_path = PathBuf::from(dir);
        opt.data_file_size = 64 * 1024 * 1024;
        // Ignored on purpose: the directory normally does not exist yet.
        let _ = std::fs::remove_dir_all(&opt.dir_path);
        opt
    }

    #[test]
    fn test_write_batch_1() {
        let opt = fresh_options("/tmp/bitkv-rs-batch-1");
        let engine = Engine::open(opt.clone()).expect("fail to open engine");

        let wb = engine
            .new_write_batch(WriteBatchOptions::default())
            .expect("fail to create write batch");

        let put_res1 = wb.put(get_test_key(1), get_test_value(10));
        assert!(put_res1.is_ok());
        let put_res2 = wb.put(get_test_key(2), get_test_value(20));
        assert!(put_res2.is_ok());

        // Staged writes stay invisible until the batch is committed.
        let res1 = engine.get(get_test_key(1));
        assert_eq!(Errors::KeyNotFound, res1.err().unwrap());

        let commit_res = wb.commit();
        assert!(commit_res.is_ok());
        let res2 = engine.get(get_test_key(1));
        assert_eq!(get_test_value(10), res2.unwrap());

        let seq_no = wb.engine.seq_no.load(Ordering::SeqCst);
        assert_eq!(2, seq_no);

        std::fs::remove_dir_all(&opt.dir_path).expect("failed to remove dir");
    }

    #[test]
    fn test_write_batch_2() {
        let opt = fresh_options("/tmp/bitkv-rs-batch-2");
        let engine = Engine::open(opt.clone()).expect("fail to open engine");

        let wb = engine
            .new_write_batch(WriteBatchOptions::default())
            .expect("fail to create write batch");

        let put_res1 = wb.put(get_test_key(1), get_test_value(10));
        assert!(put_res1.is_ok());
        let put_res2 = wb.put(get_test_key(2), get_test_value(20));
        assert!(put_res2.is_ok());
        let commit_res1 = wb.commit();
        assert!(commit_res1.is_ok());

        let put_res3 = wb.put(get_test_key(3), get_test_value(10));
        assert!(put_res3.is_ok());
        let commit_res2 = wb.commit();
        assert!(commit_res2.is_ok());

        engine.close().expect("fail to close");
        std::mem::drop(engine);

        // Reopening must recover both batches and continue the sequence numbers.
        let engine2 = Engine::open(opt.clone()).expect("fail to open engine");
        let keys = engine2.list_keys();
        assert_eq!(3, keys.unwrap().len());
        let seq_no = engine2.seq_no.load(Ordering::SeqCst);
        assert_eq!(3, seq_no);

        std::fs::remove_dir_all(&opt.dir_path).expect("failed to remove dir");
    }

    #[test]
    fn test_write_batch_3() {
        // Own directory: this used to share "/tmp/bitkv-rs-batch-2" with
        // test_write_batch_2 and raced with it under the parallel test runner.
        let opt = fresh_options("/tmp/bitkv-rs-batch-3");
        let engine = Engine::open(opt.clone()).expect("fail to open engine");

        let mut wb_opts = WriteBatchOptions::default();
        wb_opts.max_batch_num = 10_000_000;
        let wb = engine
            .new_write_batch(wb_opts)
            .expect("fail to create write batch");

        for i in 0..=1_000_000 {
            let put_res = wb.put(get_test_key(i), get_test_value(i));
            assert!(put_res.is_ok());
        }

        let commit_res1 = wb.commit();
        assert!(commit_res1.is_ok());

        std::fs::remove_dir_all(&opt.dir_path).expect("failed to remove dir");
    }

    #[test]
    fn test_write_batch_delete() {
        let opt = fresh_options("/tmp/bitkv-rs-batch-delete");
        let engine = Engine::open(opt.clone()).expect("fail to open engine");

        let wb = engine
            .new_write_batch(WriteBatchOptions::default())
            .expect("fail to create write batch");

        assert_eq!(
            Errors::KeyIsEmpty,
            wb.put(Bytes::new(), get_test_value(1)).err().unwrap()
        );
        assert_eq!(Errors::KeyIsEmpty, wb.delete(Bytes::new()).err().unwrap());

        // Deleting a key that only exists in the pending batch just drops the put.
        assert!(wb.put(get_test_key(1), get_test_value(10)).is_ok());
        assert!(wb.delete(get_test_key(1)).is_ok());
        assert!(wb.commit().is_ok());
        assert_eq!(
            Errors::KeyNotFound,
            engine.get(get_test_key(1)).err().unwrap()
        );

        // Deleting a committed key writes a tombstone that removes it from the index.
        assert!(wb.put(get_test_key(2), get_test_value(20)).is_ok());
        assert!(wb.commit().is_ok());
        assert_eq!(get_test_value(20), engine.get(get_test_key(2)).unwrap());
        assert!(wb.delete(get_test_key(2)).is_ok());
        assert!(wb.commit().is_ok());
        assert_eq!(
            Errors::KeyNotFound,
            engine.get(get_test_key(2)).err().unwrap()
        );

        std::fs::remove_dir_all(&opt.dir_path).expect("failed to remove dir");
    }

    #[test]
    fn test_log_record_key_seq_round_trip() {
        let seq_nos = [NON_TXN_SEQ_NO, 1, 127, 128, 300, u32::MAX as usize];
        for key in [b"".to_vec(), b"key".to_vec(), TXN_FIN_KEY.to_vec()] {
            for seq_no in seq_nos {
                let encoded = log_record_key_with_seq(key.clone(), seq_no);
                let (decoded_key, decoded_seq_no) = parse_log_record_key(encoded);
                assert_eq!(key, decoded_key);
                assert_eq!(seq_no, decoded_seq_no);
            }
        }
    }
}