balius_runtime/
store.rs

1use itertools::Itertools;
2use prost::Message;
3use redb::{ReadableTable as _, TableDefinition, WriteTransaction};
4use std::{path::Path, sync::Arc};
5use tracing::warn;
6
7use crate::{Block, ChainPoint, Error};
8
/// Identifier of a worker; used as the key type of the `cursors` table.
pub type WorkerId = String;
/// Sequence number of a write-ahead-log entry; used as the key type of the `wal`
/// table and as the value type of the `cursors` table.
pub type LogSeq = u64;
11
/// A single write-ahead-log record, stored protobuf-encoded (via `prost`) in the
/// `wal` table.
///
/// The `tag` numbers below define the persisted encoding; changing them would make
/// previously written WAL entries undecodable.
#[derive(Message)]
pub struct LogEntry {
    /// Serialized bytes of the block passed as `next_block` to `Store::write_ahead`.
    #[prost(bytes, tag = "1")]
    pub next_block: Vec<u8>,
    /// Serialized bytes of the blocks passed as `undo_blocks` to `Store::write_ahead`
    /// (presumably blocks to roll back before `next_block` — confirm with callers).
    #[prost(bytes, repeated, tag = "2")]
    pub undo_blocks: Vec<Vec<u8>>,
}
19
20impl redb::Value for LogEntry {
21    type SelfType<'a>
22        = LogEntry
23    where
24        Self: 'a;
25
26    type AsBytes<'a>
27        = Vec<u8>
28    where
29        Self: 'a;
30
31    fn fixed_width() -> Option<usize> {
32        None
33    }
34
35    fn from_bytes<'a>(data: &'a [u8]) -> Self::SelfType<'a>
36    where
37        Self: 'a,
38    {
39        prost::Message::decode(data).unwrap()
40    }
41
42    fn as_bytes<'a, 'b: 'a>(value: &'a Self::SelfType<'b>) -> Self::AsBytes<'a>
43    where
44        Self: 'a,
45        Self: 'b,
46    {
47        value.encode_to_vec()
48    }
49
50    fn type_name() -> redb::TypeName {
51        redb::TypeName::new("LogEntry")
52    }
53}
54
/// Table mapping each worker id to the log sequence most recently recorded for it
/// through `AtomicUpdate::update_worker_cursor`.
const CURSORS: TableDefinition<WorkerId, LogSeq> = TableDefinition::new("cursors");
/// Write-ahead log: sequence number -> entry, appended by `Store::write_ahead`.
const WAL: TableDefinition<LogSeq, LogEntry> = TableDefinition::new("wal");

/// Cache size, in MiB, that `Store::open` uses when the caller passes `None`.
const DEFAULT_CACHE_SIZE_MB: usize = 50;
59
/// A batch of worker-cursor updates staged in a single write transaction.
///
/// Created by `Store::start_atomic_update`. Every cursor written through it is set to
/// `log_seq`, and the transaction is only committed by `AtomicUpdate::commit`.
pub struct AtomicUpdate {
    /// Write transaction in which all cursor updates are staged.
    wx: WriteTransaction,
    /// Log sequence that every updated cursor is set to.
    log_seq: LogSeq,
}
64
65impl AtomicUpdate {
66    pub fn update_worker_cursor(&mut self, id: &str) -> Result<(), super::Error> {
67        let mut table = self.wx.open_table(CURSORS)?;
68        table.insert(id.to_owned(), self.log_seq)?;
69
70        Ok(())
71    }
72
73    pub fn commit(self) -> Result<(), super::Error> {
74        self.wx.commit()?;
75        Ok(())
76    }
77}
78
/// Storage for a write-ahead log of blocks plus per-worker cursors, backed by a redb
/// database (on disk via `Store::open`, or in memory).
///
/// `log_seq` caches the highest WAL sequence number written so far. `Clone` shares the
/// database through the `Arc` but copies `log_seq` by value, so clones advance their
/// counters independently — NOTE(review): confirm only one clone calls `write_ahead`.
#[derive(Clone)]
pub struct Store {
    db: Arc<redb::Database>,
    log_seq: LogSeq,
}
84
85impl Store {
86    pub fn in_memory() -> Result<Self, super::Error> {
87        let db = Arc::new(
88            redb::Database::builder().create_with_backend(redb::backends::InMemoryBackend::new())?,
89        );
90        Ok(Self { db, log_seq: 0 })
91    }
92
93    pub fn open(path: impl AsRef<Path>, cache_size: Option<usize>) -> Result<Self, super::Error> {
94        let inner = redb::Database::builder()
95            .set_repair_callback(|x| {
96                warn!(progress = x.progress() * 100f64, "balius db is repairing")
97            })
98            .set_cache_size(1024 * 1024 * cache_size.unwrap_or(DEFAULT_CACHE_SIZE_MB))
99            .create(path)?;
100
101        let log_seq = Self::load_log_seq(&inner)?.unwrap_or_default();
102
103        let out = Self {
104            db: Arc::new(inner),
105            log_seq,
106        };
107
108        Ok(out)
109    }
110
111    pub fn into_ephemeral(&mut self) -> Result<Self, super::Error> {
112        let new_db =
113            redb::Database::builder().create_with_backend(redb::backends::InMemoryBackend::new())?;
114
115        let rx = self.db.begin_read()?;
116        let wx = new_db.begin_write()?;
117
118        {
119            if let Ok(source) = rx.open_table(WAL) {
120                let mut target = wx.open_table(WAL)?;
121
122                for entry in source.iter()? {
123                    let (k, v) = entry?;
124                    target.insert(k.value(), v.value())?;
125                }
126            }
127
128            if let Ok(source) = rx.open_table(CURSORS) {
129                let mut target = wx.open_table(CURSORS)?;
130
131                for entry in source.iter()? {
132                    let (k, v) = entry?;
133                    target.insert(k.value(), v.value())?;
134                }
135            }
136        }
137
138        wx.commit()?;
139
140        let log_seq = Self::load_log_seq(&new_db)?.unwrap_or_default();
141        let new = Store {
142            db: Arc::new(new_db),
143            log_seq,
144        };
145
146        Ok(new)
147    }
148
149    fn load_log_seq(db: &redb::Database) -> Result<Option<LogSeq>, Error> {
150        let rx = db.begin_read()?;
151
152        match rx.open_table(WAL) {
153            Ok(table) => {
154                let last = table.last()?;
155                Ok(last.map(|(k, _)| k.value()))
156            }
157            Err(redb::TableError::TableDoesNotExist(_)) => Ok(None),
158            Err(e) => Err(e.into()),
159        }
160    }
161
162    fn get_entry(&self, seq: LogSeq) -> Result<Option<LogEntry>, Error> {
163        let rx = self.db.begin_read()?;
164        let table = rx.open_table(WAL)?;
165        let entry = table.get(seq)?;
166        Ok(entry.map(|x| x.value()))
167    }
168
169    pub fn find_chain_point(&self, seq: LogSeq) -> Result<Option<ChainPoint>, Error> {
170        let entry = self.get_entry(seq)?;
171        let block = Block::from_bytes(&entry.unwrap().next_block);
172
173        Ok(Some(block.chain_point()))
174    }
175
176    pub fn write_ahead(
177        &mut self,
178        undo_blocks: &[Block],
179        next_block: &Block,
180    ) -> Result<LogSeq, Error> {
181        self.log_seq += 1;
182
183        let wx = self.db.begin_write()?;
184        {
185            wx.open_table(WAL)?.insert(
186                self.log_seq,
187                LogEntry {
188                    next_block: next_block.to_bytes(),
189                    undo_blocks: undo_blocks.iter().map(|x| x.to_bytes()).collect(),
190                },
191            )?;
192        }
193
194        wx.commit()?;
195        Ok(self.log_seq)
196    }
197
198    // TODO: see if loading in batch is worth it
199    pub fn get_worker_cursor(&self, id: &str) -> Result<Option<LogSeq>, super::Error> {
200        let rx = self.db.begin_read()?;
201
202        let table = match rx.open_table(CURSORS) {
203            Ok(table) => table,
204            Err(redb::TableError::TableDoesNotExist(_)) => return Ok(None),
205            Err(e) => return Err(e.into()),
206        };
207
208        let cursor = table.get(id.to_owned())?;
209        Ok(cursor.map(|x| x.value()))
210    }
211
212    pub fn start_atomic_update(&self, log_seq: LogSeq) -> Result<AtomicUpdate, super::Error> {
213        let wx = self.db.begin_write()?;
214        Ok(AtomicUpdate { wx, log_seq })
215    }
216
217    // TODO: I don't think we need this since we're going to load each cursor as
218    // part of the loaded worker
219    pub fn lowest_cursor(&self) -> Result<Option<LogSeq>, super::Error> {
220        let rx = self.db.begin_read()?;
221
222        let table = rx.open_table(CURSORS)?;
223
224        let cursors: Vec<_> = table
225            .iter()?
226            .map_ok(|(_, value)| value.value())
227            .try_collect()?;
228
229        let lowest = cursors.iter().fold(None, |all, item| all.min(Some(*item)));
230
231        Ok(lowest)
232    }
233}