// hickory_server/store/sqlite/persistence.rs
use std::iter::Iterator;
11use std::path::Path;
12use std::sync::{Mutex, MutexGuard};
13
14use rusqlite::types::ToSql;
15use rusqlite::{self, Connection};
16use thiserror::Error;
17use time;
18use tracing::error;
19
20use crate::proto::ProtoError;
21use crate::proto::rr::Record;
22use crate::proto::serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder};
23
/// The schema version this code works with.
///
/// `Journal::schema_up` migrates a journal until its version reaches this value, and
/// `Journal::insert_record` / `Journal::select_record` assert that the journal is at it.
pub const CURRENT_VERSION: i64 = 1;
26
/// A journal of [`Record`]s stored in a SQLite database.
pub struct Journal {
    /// Connection to the backing SQLite database; the mutex lets `&self` methods use it.
    conn: Mutex<Connection>,
    /// Schema version of the database as last read or migrated by this journal
    /// (`-1` when the `tdns_schema` table did not exist at construction time).
    version: i64,
}
32
33impl Journal {
34 pub fn new(conn: Connection) -> Result<Self, PersistenceError> {
36 let version = Self::select_schema_version(&conn)?;
37 Ok(Self {
38 conn: Mutex::new(conn),
39 version,
40 })
41 }
42
43 pub fn from_file(journal_file: &Path) -> Result<Self, PersistenceError> {
45 let result = Self::new(Connection::open(journal_file)?);
46 let mut journal = result?;
47 journal.schema_up()?;
48 Ok(journal)
49 }
50
51 pub fn conn(&self) -> MutexGuard<'_, Connection> {
53 self.conn.lock().expect("conn poisoned")
54 }
55
56 pub fn schema_version(&self) -> i64 {
58 self.version
59 }
60
61 pub fn iter(&self) -> JournalIter<'_> {
63 JournalIter::new(self)
64 }
65
66 pub fn insert_record(&self, soa_serial: u32, record: &Record) -> Result<(), PersistenceError> {
75 assert!(
76 self.version == CURRENT_VERSION,
77 "schema version mismatch, schema_up() resolves this"
78 );
79
80 let mut serial_record: Vec<u8> = Vec::with_capacity(512);
81 {
82 let mut encoder = BinEncoder::new(&mut serial_record);
83 record.emit(&mut encoder)?;
84 }
85
86 let timestamp = time::OffsetDateTime::now_utc();
87 let client_id: i64 = 0; let soa_serial: i64 = i64::from(soa_serial);
89
90 let count = self.conn.lock().expect("conn poisoned").execute(
91 "INSERT
92 \
93 INTO records (client_id, soa_serial, timestamp, \
94 record)
95 \
96 VALUES ($1, $2, $3, $4)",
97 [
98 &client_id as &dyn ToSql,
99 &soa_serial,
100 ×tamp,
101 &serial_record,
102 ],
103 )?;
104 if count != 1 {
106 return Err(PersistenceError::WrongInsertCount {
107 got: count,
108 expect: 1,
109 });
110 };
111
112 Ok(())
113 }
114
115 pub fn insert_records(
117 &self,
118 soa_serial: u32,
119 records: &[Record],
120 ) -> Result<(), PersistenceError> {
121 for record in records {
123 self.insert_record(soa_serial, record)?;
124 }
125
126 Ok(())
127 }
128
129 pub fn select_record(&self, row_id: i64) -> Result<Option<(i64, Record)>, PersistenceError> {
139 assert!(
140 self.version == CURRENT_VERSION,
141 "schema version mismatch, schema_up() resolves this"
142 );
143
144 let conn = self.conn.lock().expect("conn poisoned");
145 let mut stmt = conn.prepare(
146 "SELECT _rowid_, record
147 \
148 FROM records
149 \
150 WHERE _rowid_ >= $1
151 \
152 LIMIT 1",
153 )?;
154
155 let record_opt: Option<Result<(i64, Record), rusqlite::Error>> = stmt
156 .query_and_then([&row_id], |row| -> Result<(i64, Record), rusqlite::Error> {
157 let row_id: i64 = row.get(0)?;
158 let record_bytes: Vec<u8> = row.get(1)?;
159 let mut decoder = BinDecoder::new(&record_bytes);
160
161 match Record::read(&mut decoder) {
163 Ok(record) => Ok((row_id, record)),
164 Err(decode_error) => Err(rusqlite::Error::InvalidParameterName(format!(
165 "could not decode: {decode_error}"
166 ))),
167 }
168 })?
169 .next();
170
171 match record_opt {
173 Some(Ok((row_id, record))) => Ok(Some((row_id, record))),
174 Some(Err(err)) => Err(err.into()),
175 None => Ok(None),
176 }
177 }
178
179 pub fn select_schema_version(conn: &Connection) -> Result<i64, PersistenceError> {
186 let mut stmt = conn.prepare(
188 "SELECT name
189 \
190 FROM sqlite_master
191 \
192 WHERE type='table'
193 \
194 AND name='tdns_schema'",
195 )?;
196
197 let tdns_schema_opt: Option<Result<String, _>> =
198 stmt.query_map([], |row| row.get(0))?.next();
199
200 let tdns_schema = match tdns_schema_opt {
201 Some(Ok(string)) => string,
202 Some(Err(err)) => return Err(err.into()),
203 None => return Ok(-1),
204 };
205
206 assert_eq!(&tdns_schema, "tdns_schema");
207
208 let version: i64 = conn.query_row(
209 "SELECT version
210 \
211 FROM tdns_schema",
212 [],
213 |row| row.get(0),
214 )?;
215
216 Ok(version)
217 }
218
219 fn update_schema_version(&self, new_version: i64) -> Result<(), PersistenceError> {
221 assert!(new_version <= CURRENT_VERSION);
223
224 let count = self
225 .conn
226 .lock()
227 .expect("conn poisoned")
228 .execute("UPDATE tdns_schema SET version = $1", [&new_version])?;
229
230 assert_eq!(count, 1);
232 Ok(())
233 }
234
235 pub fn schema_up(&mut self) -> Result<i64, PersistenceError> {
237 while self.version < CURRENT_VERSION {
238 match self.version + 1 {
239 0 => self.version = self.init_up()?,
240 1 => self.version = self.records_up()?,
241 _ => panic!("incorrect version somewhere"), }
243
244 self.update_schema_version(self.version)?;
245 }
246
247 Ok(self.version)
248 }
249
250 fn init_up(&self) -> Result<i64, PersistenceError> {
252 let count = self.conn.lock().expect("conn poisoned").execute(
253 "CREATE TABLE tdns_schema (
254 \
255 version INTEGER NOT NULL
256 \
257 )",
258 [],
259 )?;
260 assert_eq!(count, 0);
262
263 let count = self
264 .conn
265 .lock()
266 .expect("conn poisoned")
267 .execute("INSERT INTO tdns_schema (version) VALUES (0)", [])?;
268 assert_eq!(count, 1);
270
271 Ok(0)
272 }
273
274 fn records_up(&self) -> Result<i64, PersistenceError> {
277 let count = self.conn.lock().expect("conn poisoned").execute(
279 "CREATE TABLE records (
280 \
281 client_id INTEGER NOT NULL,
282 \
283 soa_serial INTEGER NOT NULL,
284 \
285 timestamp TEXT NOT NULL,
286 \
287 record BLOB NOT NULL
288 \
289 )",
290 [],
291 )?;
292 assert_eq!(count, 1);
294
295 Ok(1)
296 }
297}
298
/// Iterator over the records of a [`Journal`], created by `Journal::iter`.
pub struct JournalIter<'j> {
    /// Row id of the most recently returned record; starts at 0.
    current_row_id: i64,
    /// The journal the records are read from.
    journal: &'j Journal,
}
306
307impl<'j> JournalIter<'j> {
308 fn new(journal: &'j Journal) -> Self {
309 JournalIter {
310 current_row_id: 0,
311 journal,
312 }
313 }
314}
315
316impl Iterator for JournalIter<'_> {
317 type Item = Record;
318
319 fn next(&mut self) -> Option<Self::Item> {
320 match self.journal.select_record(self.current_row_id + 1) {
321 Ok(Some((row_id, record))) => {
322 self.current_row_id = row_id;
323 Some(record)
324 }
325 Ok(None) => None,
326 Err(error) => {
327 error!(%error, "persistence error while iterating over journal");
328 None
329 }
330 }
331 }
332}
333
/// Errors reported by the journal persistence layer.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum PersistenceError {
    /// An error occurred while recovering from the journal; the message describes it.
    #[error("error recovering from journal: {}", _0)]
    Recovery(&'static str),

    /// An insert statement reported a different number of rows than expected.
    #[error("wrong insert count: {} expect: {}", got, expect)]
    WrongInsertCount {
        /// The number of rows the statement reported.
        got: usize,
        /// The number of rows that should have been reported.
        expect: usize,
    },

    /// An error from the protocol layer; converted automatically from [`ProtoError`].
    #[error("proto error: {0}")]
    Proto(#[from] ProtoError),

    /// An error from SQLite; converted automatically from `rusqlite::Error`.
    /// Only available when the `sqlite` feature is enabled.
    #[cfg(feature = "sqlite")]
    #[error("sqlite error: {0}")]
    Sqlite(#[from] rusqlite::Error),

    /// A request timed out.
    #[error("request timed out")]
    Timeout,
}