Skip to main content

hickory_server/store/sqlite/
persistence.rs

1// Copyright 2015-2016 Benjamin Fry <benjaminfry -@- me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! All zone persistence related types
9
10use std::iter::Iterator;
11use std::path::Path;
12use std::sync::{Mutex, MutexGuard};
13
14use rusqlite::types::ToSql;
15use rusqlite::{self, Connection};
16use thiserror::Error;
17use time;
18use tracing::error;
19
20use crate::proto::ProtoError;
21use crate::proto::rr::Record;
22use crate::proto::serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder};
23
/// Schema version of the journal database that this code reads and writes.
///
/// [`Journal::schema_up`] migrates older databases until they reach this version.
pub const CURRENT_VERSION: i64 = 1;
26
27/// The Journal is the audit log of all changes to a zone after initial creation.
28pub struct Journal {
29    conn: Mutex<Connection>,
30    version: i64,
31}
32
33impl Journal {
34    /// Constructs a new Journal, attaching to the specified Sqlite Connection
35    pub fn new(conn: Connection) -> Result<Self, PersistenceError> {
36        let version = Self::select_schema_version(&conn)?;
37        Ok(Self {
38            conn: Mutex::new(conn),
39            version,
40        })
41    }
42
43    /// Constructs a new Journal opening a Sqlite connection to the file at the specified path
44    pub fn from_file(journal_file: &Path) -> Result<Self, PersistenceError> {
45        let result = Self::new(Connection::open(journal_file)?);
46        let mut journal = result?;
47        journal.schema_up()?;
48        Ok(journal)
49    }
50
51    /// Returns a reference to the Sqlite Connection
52    pub fn conn(&self) -> MutexGuard<'_, Connection> {
53        self.conn.lock().expect("conn poisoned")
54    }
55
56    /// Returns the current schema version of the journal
57    pub fn schema_version(&self) -> i64 {
58        self.version
59    }
60
61    /// this returns an iterator from the beginning of time, to be used to recreate a zone handler
62    pub fn iter(&self) -> JournalIter<'_> {
63        JournalIter::new(self)
64    }
65
66    /// Inserts a record, this is an append only operation.
67    ///
68    /// Records should never be posthumously modified. The first message serialized to the journal
69    /// should be a single AXFR of the entire zone to be used as a starting point for reconstruction.
70    ///
71    /// # Argument
72    ///
73    /// * `record` - will be serialized into the journal
74    pub fn insert_record(&self, soa_serial: u32, record: &Record) -> Result<(), PersistenceError> {
75        assert!(
76            self.version == CURRENT_VERSION,
77            "schema version mismatch, schema_up() resolves this"
78        );
79
80        let mut serial_record: Vec<u8> = Vec::with_capacity(512);
81        {
82            let mut encoder = BinEncoder::new(&mut serial_record);
83            record.emit(&mut encoder)?;
84        }
85
86        let timestamp = time::OffsetDateTime::now_utc();
87        let client_id: i64 = 0; // TODO: we need better id information about the client, like pub_key
88        let soa_serial: i64 = i64::from(soa_serial);
89
90        let count = self.conn.lock().expect("conn poisoned").execute(
91            "INSERT
92                                          \
93                                            INTO records (client_id, soa_serial, timestamp, \
94                                            record)
95                                          \
96                                            VALUES ($1, $2, $3, $4)",
97            [
98                &client_id as &dyn ToSql,
99                &soa_serial,
100                &timestamp,
101                &serial_record,
102            ],
103        )?;
104        //
105        if count != 1 {
106            return Err(PersistenceError::WrongInsertCount {
107                got: count,
108                expect: 1,
109            });
110        };
111
112        Ok(())
113    }
114
115    /// Inserts a set of records into the Journal, a convenience method for insert_record
116    pub fn insert_records(
117        &self,
118        soa_serial: u32,
119        records: &[Record],
120    ) -> Result<(), PersistenceError> {
121        // TODO: NEED TRANSACTION HERE
122        for record in records {
123            self.insert_record(soa_serial, record)?;
124        }
125
126        Ok(())
127    }
128
129    /// Selects a record from the given row_id.
130    ///
131    /// This allows for the entire set of records to be iterated through, by starting at 0, and
132    ///  incrementing each subsequent row.
133    ///
134    /// # Arguments
135    ///
136    /// * `row_id` - the row_id can either be exact, or start at 0 to get the earliest row in the
137    ///   list.
138    pub fn select_record(&self, row_id: i64) -> Result<Option<(i64, Record)>, PersistenceError> {
139        assert!(
140            self.version == CURRENT_VERSION,
141            "schema version mismatch, schema_up() resolves this"
142        );
143
144        let conn = self.conn.lock().expect("conn poisoned");
145        let mut stmt = conn.prepare(
146            "SELECT _rowid_, record
147                                            \
148                                               FROM records
149                                            \
150                                               WHERE _rowid_ >= $1
151                                            \
152                                               LIMIT 1",
153        )?;
154
155        let record_opt: Option<Result<(i64, Record), rusqlite::Error>> = stmt
156            .query_and_then([&row_id], |row| -> Result<(i64, Record), rusqlite::Error> {
157                let row_id: i64 = row.get(0)?;
158                let record_bytes: Vec<u8> = row.get(1)?;
159                let mut decoder = BinDecoder::new(&record_bytes);
160
161                // todo add location to this...
162                match Record::read(&mut decoder) {
163                    Ok(record) => Ok((row_id, record)),
164                    Err(decode_error) => Err(rusqlite::Error::InvalidParameterName(format!(
165                        "could not decode: {decode_error}"
166                    ))),
167                }
168            })?
169            .next();
170
171        //
172        match record_opt {
173            Some(Ok((row_id, record))) => Ok(Some((row_id, record))),
174            Some(Err(err)) => Err(err.into()),
175            None => Ok(None),
176        }
177    }
178
179    /// selects the current schema version of the journal DB, returns -1 if there is no schema
180    ///
181    ///
182    /// # Arguments
183    ///
184    /// * `conn` - db connection to use
185    pub fn select_schema_version(conn: &Connection) -> Result<i64, PersistenceError> {
186        // first see if our schema is there
187        let mut stmt = conn.prepare(
188            "SELECT name
189                                        \
190                                          FROM sqlite_master
191                                        \
192                                          WHERE type='table'
193                                        \
194                                          AND name='tdns_schema'",
195        )?;
196
197        let tdns_schema_opt: Option<Result<String, _>> =
198            stmt.query_map([], |row| row.get(0))?.next();
199
200        let tdns_schema = match tdns_schema_opt {
201            Some(Ok(string)) => string,
202            Some(Err(err)) => return Err(err.into()),
203            None => return Ok(-1),
204        };
205
206        assert_eq!(&tdns_schema, "tdns_schema");
207
208        let version: i64 = conn.query_row(
209            "SELECT version
210                                            \
211                                                FROM tdns_schema",
212            [],
213            |row| row.get(0),
214        )?;
215
216        Ok(version)
217    }
218
219    /// update the schema version
220    fn update_schema_version(&self, new_version: i64) -> Result<(), PersistenceError> {
221        // validate the versions of all the schemas...
222        assert!(new_version <= CURRENT_VERSION);
223
224        let count = self
225            .conn
226            .lock()
227            .expect("conn poisoned")
228            .execute("UPDATE tdns_schema SET version = $1", [&new_version])?;
229
230        //
231        assert_eq!(count, 1);
232        Ok(())
233    }
234
235    /// initializes the schema for the Journal
236    pub fn schema_up(&mut self) -> Result<i64, PersistenceError> {
237        while self.version < CURRENT_VERSION {
238            match self.version + 1 {
239                0 => self.version = self.init_up()?,
240                1 => self.version = self.records_up()?,
241                _ => panic!("incorrect version somewhere"), // valid panic, non-recoverable state
242            }
243
244            self.update_schema_version(self.version)?;
245        }
246
247        Ok(self.version)
248    }
249
250    /// initial schema, include the tdns_schema table for tracking the Journal version
251    fn init_up(&self) -> Result<i64, PersistenceError> {
252        let count = self.conn.lock().expect("conn poisoned").execute(
253            "CREATE TABLE tdns_schema (
254                                          \
255                                            version INTEGER NOT NULL
256                                        \
257                                            )",
258            [],
259        )?;
260        //
261        assert_eq!(count, 0);
262
263        let count = self
264            .conn
265            .lock()
266            .expect("conn poisoned")
267            .execute("INSERT INTO tdns_schema (version) VALUES (0)", [])?;
268        //
269        assert_eq!(count, 1);
270
271        Ok(0)
272    }
273
274    /// adds the records table, this is the main and single table for the history of changes to a
275    ///  zone. Each record is expected to be in the format of an update record
276    fn records_up(&self) -> Result<i64, PersistenceError> {
277        // we'll be using rowid for our primary key, basically: `rowid INTEGER PRIMARY KEY ASC`
278        let count = self.conn.lock().expect("conn poisoned").execute(
279            "CREATE TABLE records (
280                                          \
281                                            client_id      INTEGER NOT NULL,
282                                          \
283                                            soa_serial     INTEGER NOT NULL,
284                                          \
285                                            timestamp      TEXT NOT NULL,
286                                          \
287                                            record         BLOB NOT NULL
288                                        \
289                                            )",
290            [],
291        )?;
292        //
293        assert_eq!(count, 1);
294
295        Ok(1)
296    }
297}
298
299/// Returns an iterator over all items in a Journal
300///
301/// Useful for replaying an entire journal into memory to reconstruct a zone from disk
302pub struct JournalIter<'j> {
303    current_row_id: i64,
304    journal: &'j Journal,
305}
306
307impl<'j> JournalIter<'j> {
308    fn new(journal: &'j Journal) -> Self {
309        JournalIter {
310            current_row_id: 0,
311            journal,
312        }
313    }
314}
315
316impl Iterator for JournalIter<'_> {
317    type Item = Record;
318
319    fn next(&mut self) -> Option<Self::Item> {
320        match self.journal.select_record(self.current_row_id + 1) {
321            Ok(Some((row_id, record))) => {
322                self.current_row_id = row_id;
323                Some(record)
324            }
325            Ok(None) => None,
326            Err(error) => {
327                error!(%error, "persistence error while iterating over journal");
328                None
329            }
330        }
331    }
332}
333
334/// The error kind for errors that get returned in the crate
335#[derive(Debug, Error)]
336#[non_exhaustive]
337pub enum PersistenceError {
338    /// An error that occurred when recovering from journal
339    #[error("error recovering from journal: {}", _0)]
340    Recovery(&'static str),
341
342    /// The number of inserted records didn't match the expected amount
343    #[error("wrong insert count: {} expect: {}", got, expect)]
344    WrongInsertCount {
345        /// The number of inserted records
346        got: usize,
347        /// The number of records expected to be inserted
348        expect: usize,
349    },
350
351    // foreign
352    /// An error got returned by the hickory-proto crate
353    #[error("proto error: {0}")]
354    Proto(#[from] ProtoError),
355
356    /// An error got returned from the sqlite crate
357    #[cfg(feature = "sqlite")]
358    #[error("sqlite error: {0}")]
359    Sqlite(#[from] rusqlite::Error),
360
361    /// A request timed out
362    #[error("request timed out")]
363    Timeout,
364}