1use itertools::Itertools;
2use prost::Message;
3use redb::{ReadableTable as _, TableDefinition, WriteTransaction};
4use std::{path::Path, sync::Arc};
5use tracing::warn;
6
7use crate::{Block, ChainPoint, Error};
8
/// Identifier of a worker; used as the key of the `cursors` table.
pub type WorkerId = String;
/// Sequence number of a write-ahead log entry; used as the key of the `wal` table.
pub type LogSeq = u64;
11
12#[derive(Message)]
13pub struct LogEntry {
14 #[prost(bytes, tag = "1")]
15 pub next_block: Vec<u8>,
16 #[prost(bytes, repeated, tag = "2")]
17 pub undo_blocks: Vec<Vec<u8>>,
18}
19
20impl redb::Value for LogEntry {
21 type SelfType<'a>
22 = LogEntry
23 where
24 Self: 'a;
25
26 type AsBytes<'a>
27 = Vec<u8>
28 where
29 Self: 'a;
30
31 fn fixed_width() -> Option<usize> {
32 None
33 }
34
35 fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a>
36 where
37 Self: 'a,
38 {
39 prost::Message::decode(data).unwrap()
40 }
41
42 fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a>
43 where
44 Self: 'a,
45 Self: 'b,
46 {
47 value.encode_to_vec()
48 }
49
50 fn type_name() -> redb::TypeName {
51 redb::TypeName::new("LogEntry")
52 }
53}
54
55const CURSORS: TableDefinition<WorkerId, LogSeq> = TableDefinition::new("cursors");
56const WAL: TableDefinition<LogSeq, LogEntry> = TableDefinition::new("wal");
57
58const DEFAULT_CACHE_SIZE_MB: usize = 50;
59
60pub struct AtomicUpdate {
61 wx: WriteTransaction,
62 log_seq: LogSeq,
63}
64
65impl AtomicUpdate {
66 pub fn update_worker_cursor(&mut self, id: &str) -> Result<(), super::Error> {
67 let mut table = self.wx.open_table(CURSORS)?;
68 table.insert(id.to_owned(), self.log_seq)?;
69
70 Ok(())
71 }
72
73 pub fn commit(self) -> Result<(), super::Error> {
74 self.wx.commit()?;
75 Ok(())
76 }
77}
78
79#[derive(Clone)]
80pub struct Store {
81 db: Arc<redb::Database>,
82 log_seq: LogSeq,
83}
84
85impl Store {
86 pub fn in_memory() -> Result<Self, super::Error> {
87 let db = Arc::new(
88 redb::Database::builder().create_with_backend(redb::backends::InMemoryBackend::new())?,
89 );
90 Ok(Self { db, log_seq: 0 })
91 }
92
93 pub fn open(path: impl AsRef<Path>, cache_size: Option<usize>) -> Result<Self, super::Error> {
94 let inner = redb::Database::builder()
95 .set_repair_callback(|x| {
96 warn!(progress = x.progress() * 100f64, "balius db is repairing")
97 })
98 .set_cache_size(1024 * 1024 * cache_size.unwrap_or(DEFAULT_CACHE_SIZE_MB))
99 .create(path)?;
100
101 let log_seq = Self::load_log_seq(&inner)?.unwrap_or_default();
102
103 let out = Self {
104 db: Arc::new(inner),
105 log_seq,
106 };
107
108 Ok(out)
109 }
110
111 fn load_log_seq(db: &redb::Database) -> Result<Option<LogSeq>, Error> {
112 let rx = db.begin_read()?;
113
114 match rx.open_table(WAL) {
115 Ok(table) => {
116 let last = table.last()?;
117 Ok(last.map(|(k, _)| k.value()))
118 }
119 Err(redb::TableError::TableDoesNotExist(_)) => Ok(None),
120 Err(e) => Err(e.into()),
121 }
122 }
123
124 fn get_entry(&self, seq: LogSeq) -> Result<Option<LogEntry>, Error> {
125 let rx = self.db.begin_read()?;
126 let table = rx.open_table(WAL)?;
127 let entry = table.get(seq)?;
128 Ok(entry.map(|x| x.value()))
129 }
130
131 pub fn find_chain_point(&self, seq: LogSeq) -> Result<Option<ChainPoint>, Error> {
132 let entry = self.get_entry(seq)?;
133 let block = Block::from_bytes(&entry.unwrap().next_block);
134
135 Ok(Some(block.chain_point()))
136 }
137
138 pub fn write_ahead(
139 &mut self,
140 undo_blocks: &[Block],
141 next_block: &Block,
142 ) -> Result<LogSeq, Error> {
143 self.log_seq += 1;
144
145 let wx = self.db.begin_write()?;
146 {
147 wx.open_table(WAL)?.insert(
148 self.log_seq,
149 LogEntry {
150 next_block: next_block.to_bytes(),
151 undo_blocks: undo_blocks.iter().map(|x| x.to_bytes()).collect(),
152 },
153 )?;
154 }
155
156 wx.commit()?;
157 Ok(self.log_seq)
158 }
159
160 pub fn get_worker_cursor(&self, id: &str) -> Result<Option<LogSeq>, super::Error> {
162 let rx = self.db.begin_read()?;
163
164 let table = match rx.open_table(CURSORS) {
165 Ok(table) => table,
166 Err(redb::TableError::TableDoesNotExist(_)) => return Ok(None),
167 Err(e) => return Err(e.into()),
168 };
169
170 let cursor = table.get(id.to_owned())?;
171 Ok(cursor.map(|x| x.value()))
172 }
173
174 pub fn start_atomic_update(&self, log_seq: LogSeq) -> Result<AtomicUpdate, super::Error> {
175 let wx = self.db.begin_write()?;
176 Ok(AtomicUpdate { wx, log_seq })
177 }
178
179 pub fn lowest_cursor(&self) -> Result<Option<LogSeq>, super::Error> {
182 let rx = self.db.begin_read()?;
183
184 let table = rx.open_table(CURSORS)?;
185
186 let cursors: Vec<_> = table
187 .iter()?
188 .map_ok(|(_, value)| value.value())
189 .try_collect()?;
190
191 let lowest = cursors.iter().fold(None, |all, item| all.min(Some(*item)));
192
193 Ok(lowest)
194 }
195}