1use std::path::Path;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10
11use bytes::Bytes;
12use exoware_server::StoreEngine;
13use rocksdb::{ColumnFamily, ColumnFamilyDescriptor, Direction, IteratorMode, Options, DB};
14
/// Reserved key in the default CF holding the last committed batch sequence
/// number as a little-endian u64; hidden from point reads and range scans.
const SEQ_META_KEY: &[u8] = b"__simulator_seq__";
/// Column family that journals each write batch under its big-endian
/// sequence number, so CF iteration order equals sequence order.
const BATCH_LOG_CF: &str = "batch_log";
18
/// RocksDB-backed store: user keys live in the default CF, and every write
/// batch is additionally journaled in the `batch_log` CF.
#[derive(Clone)]
pub struct RocksStore {
    // Shared handle; cloning `RocksStore` clones the Arcs, not the DB.
    db: Arc<DB>,
    // In-memory batch sequence counter, recovered from SEQ_META_KEY on open.
    sequence: Arc<AtomicU64>,
    // Mirror of the latest committed sequence, stored after each successful
    // write — presumably for an external watcher; confirm with callers.
    observer: Option<Arc<AtomicU64>>,
}
28
29impl RocksStore {
30 pub fn open(path: &Path) -> Result<Self, rocksdb::Error> {
31 Self::open_with_observer(path, None)
32 }
33
34 pub fn open_with_observer(
35 path: &Path,
36 observer: Option<Arc<AtomicU64>>,
37 ) -> Result<Self, rocksdb::Error> {
38 let mut opts = Options::default();
39 opts.create_if_missing(true);
40 opts.create_missing_column_families(true);
41
42 let cf_default =
43 ColumnFamilyDescriptor::new(rocksdb::DEFAULT_COLUMN_FAMILY_NAME, Options::default());
44 let cf_batch_log = ColumnFamilyDescriptor::new(BATCH_LOG_CF, Options::default());
45 let db = Arc::new(DB::open_cf_descriptors(
46 &opts,
47 path,
48 vec![cf_default, cf_batch_log],
49 )?);
50 let seq = match db.get(SEQ_META_KEY)? {
51 Some(bytes) if bytes.len() == 8 => u64::from_le_bytes(bytes.try_into().unwrap()),
52 _ => 0,
53 };
54 Ok(Self {
55 db,
56 sequence: Arc::new(AtomicU64::new(seq)),
57 observer,
58 })
59 }
60
61 fn batch_log_cf(&self) -> &ColumnFamily {
62 self.db
63 .cf_handle(BATCH_LOG_CF)
64 .expect("batch_log CF must exist (created on open)")
65 }
66
67 fn batch_put_rocksdb(&self, kvs: &[(Bytes, Bytes)]) -> Result<u64, rocksdb::Error> {
68 let next = self.sequence.fetch_add(1, Ordering::SeqCst) + 1;
69 let encoded = encode_batch_entries(kvs);
70 let mut batch = rocksdb::WriteBatch::default();
71 for (k, v) in kvs {
72 batch.put(k.as_ref(), v.as_ref());
73 }
74 batch.put(SEQ_META_KEY, next.to_le_bytes());
75 batch.put_cf(self.batch_log_cf(), next.to_be_bytes(), &encoded);
76 self.db.write(batch)?;
77 if let Some(obs) = &self.observer {
78 obs.store(next, Ordering::SeqCst);
79 }
80 Ok(next)
81 }
82
83 fn get_rocksdb(&self, key: &[u8]) -> Result<Option<Vec<u8>>, rocksdb::Error> {
84 if key == SEQ_META_KEY {
85 return Ok(None);
86 }
87 self.db.get(key)
88 }
89
90 fn range_scan_rocksdb(
92 &self,
93 start: &[u8],
94 end: &[u8],
95 limit: usize,
96 forward: bool,
97 ) -> Result<Vec<(Bytes, Bytes)>, rocksdb::Error> {
98 if limit == 0 {
99 return Ok(Vec::new());
100 }
101 let mode = IteratorMode::From(start, Direction::Forward);
102 let mut tmp = Vec::new();
103 for item in self.db.iterator(mode) {
104 let (k, v) = item?;
105 if k.as_ref() == SEQ_META_KEY {
106 continue;
107 }
108 if k.as_ref() < start {
109 continue;
110 }
111 if !end.is_empty() && k.as_ref() > end {
112 break;
113 }
114 tmp.push((
115 Bytes::copy_from_slice(k.as_ref()),
116 Bytes::copy_from_slice(&v),
117 ));
118 }
119 if tmp.is_empty() {
120 return Ok(tmp);
121 }
122 if forward {
123 tmp.truncate(limit);
124 return Ok(tmp);
125 }
126 if tmp.len() > limit {
127 tmp = tmp.split_off(tmp.len() - limit);
128 }
129 tmp.reverse();
130 Ok(tmp)
131 }
132}
133
134impl StoreEngine for RocksStore {
135 fn put_batch(&self, kvs: &[(Bytes, Bytes)]) -> Result<u64, String> {
136 self.batch_put_rocksdb(kvs).map_err(|e| e.to_string())
137 }
138
139 fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>, String> {
140 self.get_rocksdb(key).map_err(|e| e.to_string())
141 }
142
143 fn range_scan(
144 &self,
145 start: &[u8],
146 end: &[u8],
147 limit: usize,
148 forward: bool,
149 ) -> Result<Vec<(Bytes, Bytes)>, String> {
150 self.range_scan_rocksdb(start, end, limit, forward)
151 .map_err(|e| e.to_string())
152 }
153
154 fn get_many(&self, keys: &[&[u8]]) -> Result<Vec<(Vec<u8>, Option<Vec<u8>>)>, String> {
155 let results = self.db.multi_get(keys);
156 keys.iter()
157 .zip(results)
158 .map(|(k, r)| {
159 if *k == SEQ_META_KEY {
160 return Ok((k.to_vec(), None));
161 }
162 let value = r.map_err(|e| e.to_string())?;
163 Ok((k.to_vec(), value))
164 })
165 .collect()
166 }
167
168 fn delete_batch(&self, keys: &[&[u8]]) -> Result<u64, String> {
169 let next = self.sequence.fetch_add(1, Ordering::SeqCst) + 1;
170 let mut batch = rocksdb::WriteBatch::default();
171 for k in keys {
172 batch.delete(k);
173 }
174 batch.put(SEQ_META_KEY, next.to_le_bytes());
175 batch.put_cf(
180 self.batch_log_cf(),
181 next.to_be_bytes(),
182 encode_batch_entries(&[]),
183 );
184 self.db.write(batch).map_err(|e| e.to_string())?;
185 if let Some(obs) = &self.observer {
186 obs.store(next, Ordering::SeqCst);
187 }
188 Ok(next)
189 }
190
191 fn current_sequence(&self) -> u64 {
192 self.sequence.load(Ordering::SeqCst)
193 }
194
195 fn get_batch(&self, sequence_number: u64) -> Result<Option<Vec<(Bytes, Bytes)>>, String> {
196 let cf = self.batch_log_cf();
197 match self
198 .db
199 .get_cf(cf, sequence_number.to_be_bytes())
200 .map_err(|e| e.to_string())?
201 {
202 Some(raw) => Ok(Some(decode_batch_entries(&raw).map_err(|e| e.to_string())?)),
203 None => Ok(None),
204 }
205 }
206
207 fn oldest_retained_batch(&self) -> Result<Option<u64>, String> {
208 let cf = self.batch_log_cf();
209 let mut it = self.db.iterator_cf(cf, IteratorMode::Start);
210 match it.next() {
211 None => Ok(None),
212 Some(item) => {
213 let (key, _) = item.map_err(|e| e.to_string())?;
214 if key.len() != 8 {
215 return Err(format!(
216 "batch_log CF key has unexpected length {}",
217 key.len()
218 ));
219 }
220 let mut buf = [0u8; 8];
221 buf.copy_from_slice(key.as_ref());
222 Ok(Some(u64::from_be_bytes(buf)))
223 }
224 }
225 }
226
227 fn prune_batch_log(&self, cutoff_exclusive: u64) -> Result<u64, String> {
228 let cf = self.batch_log_cf();
232 let end_key = cutoff_exclusive.to_be_bytes();
233 let mut deleted = 0u64;
234 let mut batch = rocksdb::WriteBatch::default();
235 let iter = self.db.iterator_cf(cf, IteratorMode::Start);
236 for item in iter {
237 let (k, _) = item.map_err(|e| e.to_string())?;
238 if k.as_ref() >= &end_key[..] {
239 break;
240 }
241 batch.delete_cf(cf, k.as_ref());
242 deleted += 1;
243 }
244 if deleted > 0 {
245 self.db.write(batch).map_err(|e| e.to_string())?;
246 }
247 Ok(deleted)
248 }
249}
250
251fn encode_batch_entries(kvs: &[(Bytes, Bytes)]) -> Vec<u8> {
254 let mut size = 4;
255 for (k, v) in kvs {
256 size += 4 + k.len() + 4 + v.len();
257 }
258 let mut out = Vec::with_capacity(size);
259 out.extend_from_slice(&(kvs.len() as u32).to_be_bytes());
260 for (k, v) in kvs {
261 out.extend_from_slice(&(k.len() as u32).to_be_bytes());
262 out.extend_from_slice(k.as_ref());
263 out.extend_from_slice(&(v.len() as u32).to_be_bytes());
264 out.extend_from_slice(v.as_ref());
265 }
266 out
267}
268
269fn decode_batch_entries(mut raw: &[u8]) -> Result<Vec<(Bytes, Bytes)>, String> {
270 fn take_u32(buf: &mut &[u8]) -> Result<u32, String> {
271 if buf.len() < 4 {
272 return Err("batch log truncated at u32 header".to_string());
273 }
274 let (head, rest) = buf.split_at(4);
275 *buf = rest;
276 let mut raw = [0u8; 4];
277 raw.copy_from_slice(head);
278 Ok(u32::from_be_bytes(raw))
279 }
280 fn take_n<'a>(buf: &mut &'a [u8], n: usize) -> Result<&'a [u8], String> {
281 if buf.len() < n {
282 return Err("batch log truncated at payload".to_string());
283 }
284 let (head, rest) = buf.split_at(n);
285 *buf = rest;
286 Ok(head)
287 }
288 let n = take_u32(&mut raw)? as usize;
289 let mut out = Vec::with_capacity(n);
290 for _ in 0..n {
291 let klen = take_u32(&mut raw)? as usize;
292 let k = Bytes::copy_from_slice(take_n(&mut raw, klen)?);
293 let vlen = take_u32(&mut raw)? as usize;
294 let v = Bytes::copy_from_slice(take_n(&mut raw, vlen)?);
295 out.push((k, v));
296 }
297 if !raw.is_empty() {
298 return Err(format!(
299 "batch log had {} trailing bytes after decode",
300 raw.len()
301 ));
302 }
303 Ok(out)
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding then decoding a mixed batch (ascii, empty key, binary
    /// payloads with NUL and 0xFF) must reproduce the input exactly.
    #[test]
    fn batch_entries_codec_round_trip() {
        let original = vec![
            (Bytes::from_static(b"a"), Bytes::from_static(b"1")),
            (Bytes::from_static(b""), Bytes::from_static(b"empty key ok")),
            (
                Bytes::from_static(b"binary\x00\xff"),
                Bytes::from_static(&[0u8, 1, 2, 3]),
            ),
        ];
        let round_tripped = decode_batch_entries(&encode_batch_entries(&original)).unwrap();
        assert_eq!(round_tripped, original);
    }

    /// A batch with zero entries must also survive the codec.
    #[test]
    fn empty_batch_round_trips() {
        let round_tripped = decode_batch_entries(&encode_batch_entries(&[])).unwrap();
        assert!(round_tripped.is_empty());
    }
}