kyoto/db/sqlite/
headers.rs

1use std::collections::{BTreeMap, HashSet};
2use std::fs;
3use std::ops::{Bound, RangeBounds};
4use std::path::PathBuf;
5use std::sync::Arc;
6
7use bitcoin::block::Header;
8use bitcoin::{consensus, BlockHash, Network};
9use rusqlite::{params, params_from_iter, Connection, Result};
10use tokio::sync::Mutex;
11
12use crate::db::error::{SqlHeaderStoreError, SqlInitializationError};
13use crate::db::traits::HeaderStore;
14use crate::db::BlockHeaderChanges;
15use crate::prelude::FutureResult;
16
17use super::{DATA_DIR, DEFAULT_CWD};
18
// Name of the SQLite file opened inside the per-network data directory
19const FILE_NAME: &str = "headers.db";
20// Labels for the schema version table: its name, its two columns, and the single key whose row holds the version
21const SCHEMA_TABLE_NAME: &str = "header_schema_versions";
22const SCHEMA_COLUMN: &str = "schema_key";
23const VERSION_COLUMN: &str = "version";
24const SCHEMA_KEY: &str = "current_version";
25// Update this in the case of schema changes
26const SCHEMA_VERSION: u8 = 0;
27// Always execute this query and adjust the schema with migrations.
// The column order matters: rows are read by index (0 = height, 2 = header) after a `SELECT *`.
28const INITIAL_HEADER_SCHEMA: &str = "CREATE TABLE IF NOT EXISTS headers (
29    height INTEGER PRIMARY KEY,
30    block_hash BLOB NOT NULL,
31    header BLOB NOT NULL
32) STRICT";
33
// Fixed head and tail of the range query; the `WHERE` clauses are appended between them at load time
34const LOAD_QUERY_SELECT_PREFIX: &str = "SELECT * FROM headers ";
35const LOAD_QUERY_ORDERBY_SUFFIX: &str = "ORDER BY height";
36
37/// Header storage implementation backed by SQLite.
///
/// Changes are collected in memory by `stage` and only reach the database when `write` is called.
38#[derive(Debug)]
39pub struct SqliteHeaderDb {
    // Connection to the headers database, guarded by an async mutex
40    conn: Arc<Mutex<Connection>>,
    // Staged headers keyed by height, inserted into the database by `write`
41    accepted: BTreeMap<u32, Header>,
    // Hashes of staged reorganized headers, deleted from the database by `write`
42    disconnected: HashSet<BlockHash>,
43}
44
45impl SqliteHeaderDb {
46    /// Create a new [`SqliteHeaderDb`] with an optional file path. If no path is provided,
47    /// the file will be stored in a `data` subdirectory where the program is ran.
48    pub fn new(network: Network, path: Option<PathBuf>) -> Result<Self, SqlInitializationError> {
49        let mut path = path.unwrap_or_else(|| PathBuf::from(DEFAULT_CWD));
50        path.push(DATA_DIR);
51        path.push(network.to_string());
52        if !path.exists() {
53            fs::create_dir_all(&path)?;
54        }
55        let conn = Connection::open(path.join(FILE_NAME))?;
56        // Create the schema version
57        let schema_table_query = format!(
58            "CREATE TABLE IF NOT EXISTS {SCHEMA_TABLE_NAME} ({SCHEMA_COLUMN} TEXT PRIMARY KEY, {VERSION_COLUMN} INTEGER NOT NULL)");
59        // Update the schema version
60        conn.execute(&schema_table_query, [])?;
61        let schema_init_version = format!(
62            "INSERT OR REPLACE INTO {SCHEMA_TABLE_NAME} ({SCHEMA_COLUMN}, {VERSION_COLUMN}) VALUES (?1, ?2)");
63        conn.execute(&schema_init_version, params![SCHEMA_KEY, SCHEMA_VERSION])?;
64        // Build the table if it doesn't exist
65        conn.execute(INITIAL_HEADER_SCHEMA, [])?;
66        // Migrate to any new schema versions
67        Self::migrate(&conn)?;
68
69        Ok(Self {
70            conn: Arc::new(Mutex::new(conn)),
71            accepted: BTreeMap::new(),
72            disconnected: HashSet::new(),
73        })
74    }
75
76    // This function currently does nothing, but if new columns are required this may be used to alter the tables
77    // without breaking older tables.
78    fn migrate(conn: &Connection) -> Result<(), SqlInitializationError> {
79        let version_query =
80            format!("SELECT {VERSION_COLUMN} FROM {SCHEMA_TABLE_NAME} WHERE {SCHEMA_COLUMN} = ?1");
81        let _current_version: u8 =
82            conn.query_row(&version_query, [SCHEMA_KEY], |row| row.get(0))?;
83        // Match on the version and migrate to new schemas in the future
84        Ok(())
85    }
86
87    async fn load<'a>(
88        &mut self,
89        range: impl RangeBounds<u32> + Send + Sync + 'a,
90    ) -> Result<BTreeMap<u32, Header>, SqlHeaderStoreError> {
91        let mut param_list = Vec::new();
92        let mut stmt = LOAD_QUERY_SELECT_PREFIX.to_string();
93
94        match range.start_bound() {
95            Bound::Unbounded => {
96                stmt.push_str("WHERE height >= 0 ");
97            }
98            Bound::Included(h) => {
99                stmt.push_str("WHERE height >= ? ");
100                param_list.push(*h);
101            }
102            Bound::Excluded(h) => {
103                stmt.push_str("WHERE height > ? ");
104                param_list.push(*h);
105            }
106        };
107
108        match range.end_bound() {
109            Bound::Unbounded => (),
110            Bound::Included(h) => {
111                stmt.push_str("AND height <= ? ");
112                param_list.push(*h);
113            }
114            Bound::Excluded(h) => {
115                stmt.push_str("AND height < ? ");
116                param_list.push(*h);
117            }
118        };
119
120        stmt.push_str(LOAD_QUERY_ORDERBY_SUFFIX);
121
122        let mut headers = BTreeMap::<u32, Header>::new();
123        let write_lock = self.conn.lock().await;
124        let mut query = write_lock.prepare(&stmt)?;
125        let mut rows = query.query(params_from_iter(param_list.iter()))?;
126        while let Some(row) = rows.next()? {
127            let height: u32 = row.get(0)?;
128            let header: [u8; 80] = row.get(2)?;
129            let next_header: Header = consensus::deserialize(&header)?;
130            if let Some(header) = headers.values().last() {
131                if header.block_hash().ne(&next_header.prev_blockhash) {
132                    return Err(SqlHeaderStoreError::Corruption);
133                }
134            }
135            headers.insert(height, next_header);
136        }
137        Ok(headers)
138    }
139
140    fn stage(&mut self, changes: BlockHeaderChanges) {
141        match changes {
142            BlockHeaderChanges::Connected(indexed_header) => {
143                self.accepted
144                    .insert(indexed_header.height, indexed_header.header);
145            }
146            BlockHeaderChanges::Reorganized {
147                accepted,
148                reorganized,
149            } => {
150                for indexed_header in reorganized {
151                    let removed_hash = indexed_header.header.block_hash();
152                    self.accepted
153                        .retain(|_, header| header.block_hash().ne(&removed_hash));
154                    self.disconnected.insert(removed_hash);
155                }
156                for indexed_header in accepted {
157                    self.accepted
158                        .insert(indexed_header.height, indexed_header.header);
159                }
160            }
161        }
162    }
163
164    async fn write(&mut self) -> Result<(), SqlHeaderStoreError> {
165        let mut write_lock = self.conn.lock().await;
166        let tx = write_lock.transaction()?;
167        for removed in core::mem::take(&mut self.disconnected) {
168            let hash: Vec<u8> = consensus::serialize(&removed);
169            let stmt = "DELETE FROM headers WHERE block_hash = ?1";
170            tx.execute(stmt, params![hash])?;
171        }
172        for (height, header) in core::mem::take(&mut self.accepted) {
173            let hash: Vec<u8> = consensus::serialize(&header.block_hash());
174            let header: Vec<u8> = consensus::serialize(&header);
175            let stmt =
176                "INSERT OR REPLACE INTO headers (height, block_hash, header) VALUES (?1, ?2, ?3)";
177            tx.execute(stmt, params![height, hash, header])?;
178        }
179        tx.commit()?;
180        Ok(())
181    }
182
183    async fn height_of(
184        &mut self,
185        block_hash: &BlockHash,
186    ) -> Result<Option<u32>, SqlHeaderStoreError> {
187        let write_lock = self.conn.lock().await;
188        let stmt = "SELECT height FROM headers WHERE block_hash = ?1";
189        let hash: Vec<u8> = consensus::serialize(&block_hash);
190        let row: Option<u32> = write_lock.query_row(stmt, params![hash], |row| row.get(0))?;
191        Ok(row)
192    }
193
194    async fn hash_at(&mut self, height: u32) -> Result<Option<BlockHash>, SqlHeaderStoreError> {
195        let write_lock = self.conn.lock().await;
196        let stmt = "SELECT block_hash FROM headers WHERE height = ?1";
197        let row: Option<[u8; 32]> =
198            write_lock.query_row(stmt, params![height], |row| row.get(0))?;
199        match row {
200            Some(hash) => Ok(Some(consensus::deserialize(&hash)?)),
201            None => Ok(None),
202        }
203    }
204
205    async fn header_at(&mut self, height: u32) -> Result<Option<Header>, SqlHeaderStoreError> {
206        let write_lock = self.conn.lock().await;
207        let stmt = "SELECT * FROM headers WHERE height = ?1";
208        let query = write_lock.query_row(stmt, params![height], |row| {
209            let header_slice: [u8; 80] = row.get(2)?;
210            consensus::deserialize(&header_slice)
211                .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))
212        });
213        match query {
214            Ok(header) => Ok(Some(header)),
215            Err(e) => match e {
216                rusqlite::Error::QueryReturnedNoRows => Ok(None),
217                _ => Err(SqlHeaderStoreError::SQL(e)),
218            },
219        }
220    }
221}
222
223impl HeaderStore for SqliteHeaderDb {
224    type Error = SqlHeaderStoreError;
225
226    fn load<'a>(
227        &'a mut self,
228        range: impl RangeBounds<u32> + Send + Sync + 'a,
229    ) -> FutureResult<'a, BTreeMap<u32, Header>, Self::Error> {
230        Box::pin(self.load(range))
231    }
232
233    fn stage(&mut self, changes: BlockHeaderChanges) {
234        self.stage(changes)
235    }
236
237    fn write(&mut self) -> FutureResult<'_, (), Self::Error> {
238        Box::pin(self.write())
239    }
240
241    fn height_of<'a>(
242        &'a mut self,
243        hash: &'a BlockHash,
244    ) -> FutureResult<'a, Option<u32>, Self::Error> {
245        Box::pin(self.height_of(hash))
246    }
247
248    fn hash_at(&mut self, height: u32) -> FutureResult<'_, Option<BlockHash>, Self::Error> {
249        Box::pin(self.hash_at(height))
250    }
251
252    fn header_at(&mut self, height: u32) -> FutureResult<'_, Option<Header>, Self::Error> {
253        Box::pin(self.header_at(height))
254    }
255}
256
#[cfg(test)]
mod tests {
    use crate::chain::IndexedHeader;

    use super::*;
    use bitcoin::consensus::deserialize;

    // Serialized headers the tests store at heights 8, 9 and 10
    const BLOCK_8: &str = "0000002016fe292517eecbbd63227d126a6b1db30ebc5262c61f8f3a4a529206388fc262dfd043cef8454f71f30b5bbb9eb1a4c9aea87390f429721e435cf3f8aa6e2a9171375166ffff7f2000000000";
    const BLOCK_9: &str = "000000205708a90197d93475975545816b2229401ccff7567cb23900f14f2bd46732c605fd8de19615a1d687e89db365503cdf58cb649b8e935a1d3518fa79b0d408704e71375166ffff7f2000000000";
    const BLOCK_10: &str = "000000201d062f2162835787db536c55317e08df17c58078c7610328bdced198574093790c9f554a7780a6043a19619d2a4697364bb62abf6336c0568c31f1eedca3c3e171375166ffff7f2000000000";
    // Serialized headers the fork test accepts at heights 10 and 11 in place of `BLOCK_10`
    const FORK_BLOCK_10: &str = "000000201d062f2162835787db536c55317e08df17c58078c7610328bdced198574093792151c0e9ce4e4c789ca98427d7740cc7acf30d2ca0c08baef266bf152289d814567e5e66ffff7f2001000000";
    const FORK_BLOCK_11: &str = "00000020efcf8b12221fccc735b9b0b657ce15b31b9c50aff530ce96a5b4cfe02d8c0068496c1b8a89cf5dec22e46c35ea1035f80f5b666a1b3aa7f3d6f0880d0061adcc567e5e66ffff7f2001000000";

    // Decode a hex-encoded, consensus-serialized header
    fn header_from_hex(raw: &str) -> Header {
        deserialize(&hex::decode(raw).unwrap()).unwrap()
    }

    // Stage the headers at heights 8, 9 and 10 as connected, write them, and return them by height
    async fn write_base_chain(db: &mut SqliteHeaderDb) -> BTreeMap<u32, Header> {
        let chain = BTreeMap::from([
            (8, header_from_hex(BLOCK_8)),
            (9, header_from_hex(BLOCK_9)),
            (10, header_from_hex(BLOCK_10)),
        ]);
        for (height, header) in &chain {
            db.stage(BlockHeaderChanges::Connected(IndexedHeader::new(
                *height, *header,
            )));
        }
        let w = db.write().await;
        assert!(w.is_ok());
        chain
    }

    #[tokio::test]
    async fn test_sql_header_store_normal_use() {
        let binding = tempfile::tempdir().unwrap();
        let path = binding.path();
        let mut db = SqliteHeaderDb::new(Network::Regtest, Some(path.into())).unwrap();
        let map = write_base_chain(&mut db).await;
        let block_8 = map[&8];
        let block_9 = map[&9];
        let get_hash_9 = db.hash_at(9).await.unwrap().unwrap();
        assert_eq!(get_hash_9, block_9.block_hash());
        let get_height_8 = db.height_of(&block_8.block_hash()).await.unwrap().unwrap();
        assert_eq!(get_height_8, 8);
        let load = db.load(7..).await.unwrap();
        assert_eq!(map, load);
        let get_header_9 = db.header_at(9).await.unwrap().unwrap();
        assert_eq!(get_header_9, block_9);
        // Heights on either side of the stored range hold no header
        let get_header_11 = db.header_at(11).await.unwrap();
        assert!(get_header_11.is_none());
        let get_header_7 = db.header_at(7).await.unwrap();
        assert!(get_header_7.is_none());
        drop(db);
        binding.close().unwrap();
    }

    #[tokio::test]
    async fn test_sql_header_loads_with_fork() {
        let binding = tempfile::tempdir().unwrap();
        let path = binding.path();
        let mut db = SqliteHeaderDb::new(Network::Regtest, Some(path.into())).unwrap();
        let mut map = write_base_chain(&mut db).await;
        let block_10 = map[&10];
        let get_header_10 = db.header_at(10).await.unwrap().unwrap();
        assert_eq!(block_10, get_header_10);
        let new_block_10 = header_from_hex(FORK_BLOCK_10);
        let block_11 = header_from_hex(FORK_BLOCK_11);
        let accepted = vec![
            IndexedHeader::new(10, new_block_10),
            IndexedHeader::new(11, block_11),
        ];
        let reorganized = vec![IndexedHeader::new(10, block_10)];
        db.stage(BlockHeaderChanges::Reorganized {
            accepted,
            reorganized,
        });
        let w = db.write().await;
        assert!(w.is_ok());
        let get_header_10 = db.header_at(10).await.unwrap().unwrap();
        assert_eq!(new_block_10, get_header_10);
        let get_header_12 = db.header_at(12).await.unwrap();
        assert!(get_header_12.is_none());
        let get_hash_10 = db.hash_at(10).await.unwrap().unwrap();
        assert_eq!(get_hash_10, new_block_10.block_hash());
        let get_height_11 = db.height_of(&block_11.block_hash()).await.unwrap().unwrap();
        assert_eq!(get_height_11, 11);
        // The reorganized header must no longer resolve to a height
        let stale_height = db.height_of(&block_10.block_hash()).await;
        assert!(!matches!(stale_height, Ok(Some(_))));
        map.insert(10, new_block_10);
        map.insert(11, block_11);
        let load = db.load(7..).await.unwrap();
        assert_eq!(map, load);
        drop(db);
        binding.close().unwrap();
    }

    #[tokio::test]
    async fn test_range_loads_properly() {
        let binding = tempfile::tempdir().unwrap();
        let path = binding.path();
        let mut db = SqliteHeaderDb::new(Network::Regtest, Some(path.into())).unwrap();
        let mut map = write_base_chain(&mut db).await;
        let load = db.load(7..).await.unwrap();
        assert_eq!(map, load);
        let load = db.load(8..).await.unwrap();
        assert_eq!(map, load);
        // An excluded upper bound leaves out height 10
        let load = db.load(8..10).await.unwrap();
        map.remove(&10);
        assert_eq!(map, load);
        let load = db.load(..10).await.unwrap();
        assert_eq!(map, load);
        drop(db);
        binding.close().unwrap();
    }
}