use std::iter::Iterator;
use std::path::Path;
use std::sync::{Mutex, MutexGuard};
use rusqlite::types::ToSql;
use rusqlite::{self, Connection};
use thiserror::Error;
use time;
use tracing::error;
use crate::proto::ProtoError;
use crate::proto::rr::Record;
use crate::proto::serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder};
/// Schema version that [`Journal::schema_up`] migrates a journal to.
///
/// Each value up to this one has a matching migration step in `schema_up`; bump it together
/// with adding a new step there.
pub const CURRENT_VERSION: i64 = 1;
/// A journal of [`Record`]s persisted in a SQLite database.
pub struct Journal {
    /// The SQLite connection; wrapped in a mutex so `&self` methods can use it.
    conn: Mutex<Connection>,
    /// Schema version read at construction and advanced by `schema_up`.
    version: i64,
}
impl Journal {
    /// Wraps an open SQLite connection as a journal.
    ///
    /// The current schema version is read from the `tdns_schema` table (`-1` when that table
    /// does not exist). No migration is performed here: call [`Journal::schema_up`] before
    /// inserting or selecting records, since those methods assert the schema is current.
    ///
    /// # Errors
    ///
    /// Returns an error if the schema version cannot be queried.
    pub fn new(conn: Connection) -> Result<Self, PersistenceError> {
        let version = Self::select_schema_version(&conn)?;
        Ok(Self {
            conn: Mutex::new(conn),
            version,
        })
    }

    /// Opens the SQLite database at `journal_file` and migrates it to [`CURRENT_VERSION`].
    ///
    /// # Errors
    ///
    /// Returns an error if the database cannot be opened, or if its schema cannot be read or
    /// migrated.
    pub fn from_file(journal_file: &Path) -> Result<Self, PersistenceError> {
        let mut journal = Self::new(Connection::open(journal_file)?)?;
        journal.schema_up()?;
        Ok(journal)
    }

    /// Locks and returns the underlying SQLite connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection mutex is poisoned.
    pub fn conn(&self) -> MutexGuard<'_, Connection> {
        self.conn.lock().expect("conn poisoned")
    }

    /// Returns the schema version this journal was last read at or migrated to.
    pub fn schema_version(&self) -> i64 {
        self.version
    }

    /// Returns an iterator over the journaled records in ascending row-id order.
    pub fn iter(&self) -> JournalIter<'_> {
        JournalIter::new(self)
    }

    /// Appends `record` to the journal, tagged with `soa_serial` and the current UTC time.
    ///
    /// # Errors
    ///
    /// Returns an error if `record` cannot be encoded, if the INSERT fails, or if it does not
    /// affect exactly one row ([`PersistenceError::WrongInsertCount`]).
    ///
    /// # Panics
    ///
    /// Panics if the schema is not at [`CURRENT_VERSION`] (call [`Journal::schema_up`] first),
    /// or if the connection mutex is poisoned.
    pub fn insert_record(&self, soa_serial: u32, record: &Record) -> Result<(), PersistenceError> {
        assert!(
            self.version == CURRENT_VERSION,
            "schema version mismatch, schema_up() resolves this"
        );

        let mut serial_record: Vec<u8> = Vec::with_capacity(512);
        {
            // Scoped so the encoder's mutable borrow of the buffer ends before the insert.
            let mut encoder = BinEncoder::new(&mut serial_record);
            record.emit(&mut encoder)?;
        }

        let timestamp = time::OffsetDateTime::now_utc();
        // Always 0 here. NOTE(review): presumably reserved for distinguishing writers — confirm.
        let client_id: i64 = 0;
        // Lossless widening so the serial is bound as a signed 64-bit integer.
        let soa_serial: i64 = i64::from(soa_serial);

        let count = self.conn().execute(
            "INSERT INTO records (client_id, soa_serial, timestamp, record) \
             VALUES ($1, $2, $3, $4)",
            [
                &client_id as &dyn ToSql,
                &soa_serial,
                &timestamp,
                &serial_record,
            ],
        )?;

        if count != 1 {
            return Err(PersistenceError::WrongInsertCount {
                got: count,
                expect: 1,
            });
        }
        Ok(())
    }

    /// Appends every record in `records` with the same `soa_serial`, in slice order.
    ///
    /// Each record is a separate INSERT with no enclosing transaction, so if one fails the
    /// records inserted before it remain in the journal.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by [`Journal::insert_record`].
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`Journal::insert_record`].
    pub fn insert_records(
        &self,
        soa_serial: u32,
        records: &[Record],
    ) -> Result<(), PersistenceError> {
        for record in records {
            self.insert_record(soa_serial, record)?;
        }
        Ok(())
    }

    /// Returns the record with the smallest row id that is `>= row_id`, paired with that row id.
    ///
    /// Returns `Ok(None)` when no such row exists.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails or the stored bytes cannot be decoded as a
    /// [`Record`].
    ///
    /// # Panics
    ///
    /// Panics if the schema is not at [`CURRENT_VERSION`], or if the connection mutex is
    /// poisoned.
    pub fn select_record(&self, row_id: i64) -> Result<Option<(i64, Record)>, PersistenceError> {
        assert!(
            self.version == CURRENT_VERSION,
            "schema version mismatch, schema_up() resolves this"
        );

        let conn = self.conn();
        // Without ORDER BY, SQLite does not define which matching row LIMIT 1 returns;
        // JournalIter relies on receiving the smallest one so it never skips records.
        let mut stmt = conn.prepare(
            "SELECT _rowid_, record \
             FROM records \
             WHERE _rowid_ >= $1 \
             ORDER BY _rowid_ ASC \
             LIMIT 1",
        )?;

        let record_opt: Option<Result<(i64, Record), rusqlite::Error>> = stmt
            .query_and_then([&row_id], |row| -> Result<(i64, Record), rusqlite::Error> {
                let row_id: i64 = row.get(0)?;
                let record_bytes: Vec<u8> = row.get(1)?;
                let mut decoder = BinDecoder::new(&record_bytes);
                match Record::read(&mut decoder) {
                    Ok(record) => Ok((row_id, record)),
                    // NOTE(review): decode failures are surfaced through a rusqlite variant
                    // whose name does not describe the cause; only the message is meaningful.
                    Err(decode_error) => Err(rusqlite::Error::InvalidParameterName(format!(
                        "could not decode: {decode_error}"
                    ))),
                }
            })?
            .next();

        Ok(record_opt.transpose()?)
    }

    /// Reads the schema version stored in `conn`'s `tdns_schema` table.
    ///
    /// Returns `-1` when the `tdns_schema` table does not exist yet, i.e. before the first
    /// migration step has run.
    ///
    /// # Errors
    ///
    /// Returns an error if either query fails.
    pub fn select_schema_version(conn: &Connection) -> Result<i64, PersistenceError> {
        let mut stmt = conn.prepare(
            "SELECT name \
             FROM sqlite_master \
             WHERE type='table' \
             AND name='tdns_schema'",
        )?;
        let tdns_schema_opt: Option<Result<String, _>> =
            stmt.query_map([], |row| row.get(0))?.next();
        let tdns_schema = match tdns_schema_opt {
            Some(Ok(string)) => string,
            Some(Err(err)) => return Err(err.into()),
            None => return Ok(-1),
        };
        assert_eq!(&tdns_schema, "tdns_schema");

        let version: i64 = conn.query_row("SELECT version FROM tdns_schema", [], |row| row.get(0))?;
        Ok(version)
    }

    /// Overwrites the version stored in `tdns_schema` with `new_version`.
    ///
    /// # Panics
    ///
    /// Panics if `new_version` exceeds [`CURRENT_VERSION`], or if the UPDATE does not affect
    /// exactly one row.
    fn update_schema_version(&self, new_version: i64) -> Result<(), PersistenceError> {
        assert!(new_version <= CURRENT_VERSION);
        let count = self
            .conn()
            .execute("UPDATE tdns_schema SET version = $1", [&new_version])?;
        assert_eq!(count, 1);
        Ok(())
    }

    /// Runs each pending migration step in order until the schema reaches
    /// [`CURRENT_VERSION`], recording the new version after every step.
    ///
    /// Version `-1` (no schema) first runs `init_up` (to 0), then `records_up` (to 1).
    /// Steps are not wrapped in a transaction. Returns the resulting version.
    ///
    /// # Errors
    ///
    /// Returns the first error raised by a migration step or by recording the version.
    ///
    /// # Panics
    ///
    /// Panics if a version has no matching migration step.
    pub fn schema_up(&mut self) -> Result<i64, PersistenceError> {
        while self.version < CURRENT_VERSION {
            match self.version + 1 {
                0 => self.version = self.init_up()?,
                1 => self.version = self.records_up()?,
                _ => panic!("incorrect version somewhere"),
            }
            self.update_schema_version(self.version)?;
        }
        Ok(self.version)
    }

    /// Migration step to version 0: creates `tdns_schema` and seeds it with version 0.
    fn init_up(&self) -> Result<i64, PersistenceError> {
        let conn = self.conn();
        // The count `execute` returns for DDL is sqlite3_changes(), which only INSERT/UPDATE/
        // DELETE modify; it reflects earlier DML on this connection and says nothing about
        // CREATE TABLE, so it is deliberately not checked.
        conn.execute(
            "CREATE TABLE tdns_schema ( \
             version INTEGER NOT NULL \
             )",
            [],
        )?;

        let count = conn.execute("INSERT INTO tdns_schema (version) VALUES (0)", [])?;
        assert_eq!(count, 1);

        Ok(0)
    }

    /// Migration step to version 1: creates the `records` table that holds the journal.
    fn records_up(&self) -> Result<i64, PersistenceError> {
        // Rows are addressed by SQLite's implicit `_rowid_` (see `select_record`), so no
        // explicit primary key is declared. As with any DDL, the count `execute` returns is
        // stale sqlite3_changes() state rather than a result of CREATE TABLE, so it is not
        // checked; asserting on it panicked when resuming a migration from version 0.
        self.conn().execute(
            "CREATE TABLE records ( \
             client_id INTEGER NOT NULL, \
             soa_serial INTEGER NOT NULL, \
             timestamp TEXT NOT NULL, \
             record BLOB NOT NULL \
             )",
            [],
        )?;

        Ok(1)
    }
}
/// Cursor over the records of a [`Journal`], fetching one row per call to `next`.
///
/// Created by [`Journal::iter`].
pub struct JournalIter<'j> {
    /// The journal being read, borrowed for the iterator's lifetime.
    journal: &'j Journal,
    /// Row id of the last record yielded (0 before the first record).
    current_row_id: i64,
}
impl<'j> JournalIter<'j> {
fn new(journal: &'j Journal) -> Self {
JournalIter {
current_row_id: 0,
journal,
}
}
}
impl Iterator for JournalIter<'_> {
    type Item = Record;

    /// Yields the first record at or after `current_row_id + 1` and advances past it.
    ///
    /// A persistence error is logged and reported as `None`; the cursor is not advanced in
    /// that case, so a later call retries the same position.
    fn next(&mut self) -> Option<Self::Item> {
        let lookup_from = self.current_row_id + 1;
        let found = match self.journal.select_record(lookup_from) {
            Ok(found) => found,
            Err(error) => {
                error!(%error, "persistence error while iterating over journal");
                return None;
            }
        };
        let (row_id, record) = found?;
        self.current_row_id = row_id;
        Some(record)
    }
}
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum PersistenceError {
#[error("error recovering from journal: {}", _0)]
Recovery(&'static str),
#[error("wrong insert count: {} expect: {}", got, expect)]
WrongInsertCount {
got: usize,
expect: usize,
},
#[error("proto error: {0}")]
Proto(#[from] ProtoError),
#[cfg(feature = "sqlite")]
#[error("sqlite error: {0}")]
Sqlite(#[from] rusqlite::Error),
#[error("request timed out")]
Timeout,
}