mod migrations;
use std::collections::HashMap;
use std::convert::Infallible;
use std::fmt;
use std::marker::PhantomData;
use std::path::Path;
use std::sync::Arc;
use std::time;
use radicle_cob::ObjectId;
use sqlite as sql;
use thiserror::Error;
use crate::prelude::RepoId;
use crate::sql::transaction;
/// Filename of the collaborative objects cache database.
pub const COBS_DB_FILE: &str = "cache.db";

/// How long a reading connection waits on a busy database before failing.
const DB_READ_TIMEOUT: time::Duration = time::Duration::from_secs(3);
/// How long a writing connection waits on a busy database before failing.
const DB_WRITE_TIMEOUT: time::Duration = time::Duration::from_secs(6);

/// Ordered list of schema migrations. Index `i` corresponds to schema
/// version `i + 1`, tracked via SQLite's `user_version` pragma.
const MIGRATIONS: &[Migration] = &[
    Migration::Sql(include_str!("cache/migrations/1.sql")),
    Migration::Native(migrations::_2::run),
    Migration::Sql(include_str!("cache/migrations/3.sql")),
];

/// Signature of a migration implemented in native code rather than SQL.
/// Receives the connection, the overall migration progress, and a callback
/// for reporting row-level progress.
type MigrateFn = fn(&sql::Connection, &Progress, &mut dyn MigrateCallback) -> Result<usize, Error>;

/// A single schema migration: either a plain SQL script or a native function.
enum Migration {
    Sql(&'static str),
    Native(MigrateFn),
}
/// Progress report handed to a [`MigrateCallback`] while migrations run.
#[derive(Debug)]
pub struct MigrateProgress<'a> {
    /// Which migration (out of how many) is currently being applied.
    pub migration: &'a Progress,
    /// Row-level progress within the current migration.
    pub rows: &'a Progress,
}
impl MigrateProgress<'_> {
    /// `true` once every migration and every row within the current
    /// migration has been processed.
    pub fn is_done(&self) -> bool {
        let migrations_done = self.migration.current() == self.migration.total();
        let rows_done = self.rows.current() == self.rows.total();

        migrations_done && rows_done
    }
}
/// Receives progress notifications while database migrations run.
pub trait MigrateCallback {
    /// Called whenever migration progress advances.
    fn progress(&mut self, progress: MigrateProgress<'_>);
}
/// Any `Fn(MigrateProgress)` closure or function can serve as a
/// migration callback, e.g. [`migrate::log`] or [`migrate::ignore`].
impl<F> MigrateCallback for F
where
    F: Fn(MigrateProgress),
{
    fn progress(&mut self, progress: MigrateProgress) {
        (*self)(progress)
    }
}
/// Ready-made [`MigrateCallback`] implementations.
pub mod migrate {
    use super::*;

    /// Callback that traces migration progress to the `db` log target.
    pub fn log(progress: MigrateProgress<'_>) {
        let migration = progress.migration;

        log::trace!(
            target: "db",
            "Migration {}/{} in progress.. ({}%)",
            migration.current(),
            migration.total(),
            progress.rows.percentage()
        );
    }

    /// Callback that discards all progress notifications.
    pub fn ignore(_progress: MigrateProgress<'_>) {}
}
/// Errors that can occur while accessing the cache database.
#[derive(Error, Debug)]
pub enum Error {
    /// Underlying SQLite failure.
    #[error("internal error: {0}")]
    Internal(#[from] sql::Error),
    /// A stored JSON value did not have the expected structure.
    #[error("malformed JSON schema")]
    MalformedJsonSchema,
    /// A stored JSON payload failed to deserialize.
    #[error("malformed JSON data: {0}")]
    MalformedJson(serde_json::Error),
    /// A query expected at least one row but returned none.
    #[error("no rows returned")]
    NoRows,
    /// The database schema version is behind the migrations compiled into
    /// this binary; migrations must be run before using the store.
    #[error("collaborative objects database is out of date")]
    OutOfDate,
}
/// A cache database handle with write access.
pub type StoreWriter = Store<Write>;
/// A cache database handle with read-only access.
pub type StoreReader = Store<Read>;

/// Marker type for read-only stores.
#[derive(Clone)]
pub struct Read;
/// Marker type for writable stores.
#[derive(Clone)]
pub struct Write;

/// A connection to the collaborative objects cache database,
/// parameterized by access mode ([`Read`] or [`Write`]).
#[derive(Clone)]
pub struct Store<T> {
    // Thread-safe connection, shared between clones of this store.
    pub(super) db: Arc<sql::ConnectionThreadSafe>,
    // Zero-sized access-mode marker; no `T` value is ever stored.
    marker: PhantomData<T>,
}
/// Opaque debug representation; the underlying connection isn't `Debug`.
impl<T> fmt::Debug for Store<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // NOTE(review): prints "Database" rather than "Store" — presumably
        // historical naming; confirm before changing.
        write!(f, "Database")
    }
}
impl Store<Read> {
pub fn reader<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
let mut db = sql::Connection::open_thread_safe_with_flags(
path,
sqlite::OpenFlags::new().with_read_only(),
)?;
db.set_busy_timeout(DB_READ_TIMEOUT.as_millis() as usize)?;
Ok(Self {
db: Arc::new(db),
marker: PhantomData,
})
}
pub fn memory() -> Result<Self, Error> {
let mut db = sql::Connection::open_thread_safe_with_flags(
":memory:",
sqlite::OpenFlags::new().with_read_only(),
)?;
db.set_busy_timeout(DB_READ_TIMEOUT.as_millis() as usize)?;
Ok(Self {
db: Arc::new(db),
marker: PhantomData,
})
}
}
impl Store<Write> {
    /// Open a writable store at the given path, creating the database
    /// file if it doesn't exist.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
        let mut db = sql::Connection::open_thread_safe(path)?;
        db.set_busy_timeout(DB_WRITE_TIMEOUT.as_millis() as usize)?;
        Ok(Self {
            db: Arc::new(db),
            marker: PhantomData,
        })
    }

    /// Open a writable store backed by an in-memory database.
    // NOTE(review): unlike `open`, no busy timeout is set here — presumably
    // unnecessary for a private in-memory database; confirm.
    pub fn memory() -> Result<Self, Error> {
        let db = Arc::new(sql::Connection::open_thread_safe(":memory:")?);
        Ok(Self {
            db,
            marker: PhantomData,
        })
    }

    /// Run all pending migrations, then return the store (builder style).
    pub fn with_migrations<M: MigrateCallback>(mut self, callback: M) -> Result<Self, Error> {
        self.migrate(callback).map(|_| self)
    }

    /// Downgrade this writable handle into a read-only one, reusing the
    /// same underlying connection.
    pub fn read_only(self) -> Store<Read> {
        Store {
            db: self.db,
            marker: PhantomData,
        }
    }

    /// Run an arbitrary query against the database, inside a transaction.
    pub fn raw_query<T, E, F>(&self, query: F) -> Result<T, E>
    where
        F: FnOnce(&sql::Connection) -> Result<T, E>,
        E: From<sql::Error>,
    {
        transaction(&self.db, query)
    }

    /// Migrate to the latest schema version known to this binary.
    /// Returns the resulting schema version.
    pub fn migrate<M: MigrateCallback>(&mut self, callback: M) -> Result<usize, Error> {
        self.migrate_to(MIGRATIONS.len(), callback)
    }

    /// Migrate up to (at most) schema version `target`, reporting progress
    /// through `callback`. Already-applied migrations are skipped, so a
    /// `target` at or below the current version is a no-op. Returns the
    /// schema version after migrating.
    pub fn migrate_to<M: MigrateCallback>(
        &mut self,
        target: usize,
        mut callback: M,
    ) -> Result<usize, Error> {
        let db = &self.db;
        let mut version = version(db)?;
        let total = MIGRATIONS.len();
        // `skip(version)` drops migrations that already ran; `take(target)`
        // caps the run at the requested version (index `i` = version `i + 1`).
        for (i, migration) in MIGRATIONS.iter().enumerate().take(target).skip(version) {
            let current = i + 1;
            // Each migration and its version bump commit atomically.
            transaction(db, |db| {
                match migration {
                    Migration::Sql(query) => {
                        db.execute(query)?;
                        // SQL migrations complete in a single step.
                        callback.progress(MigrateProgress {
                            migration: &Progress { total, current },
                            rows: &Progress::done(1),
                        });
                    }
                    Migration::Native(migrate) => {
                        // Native migrations drive their own row-level progress.
                        migrate(db, &Progress { total, current }, &mut callback)?;
                    }
                }
                version = bump(db)?;
                Ok::<_, Error>(())
            })?;
        }
        Ok(version)
    }
}
impl<T> Store<T> {
pub fn version(&self) -> Result<usize, Error> {
version(&self.db)
}
pub fn check_version(&self) -> Result<(), Error> {
if version(&self.db)? < MIGRATIONS.len() {
return Err(Error::OutOfDate);
}
Ok(())
}
}
pub fn version(db: &sql::Connection) -> Result<usize, Error> {
let version = db
.prepare("PRAGMA user_version")?
.into_iter()
.next()
.ok_or(Error::NoRows)??
.read::<i64, _>(0);
Ok(version as usize)
}
/// Increment the `user_version` pragma by one and return the new version.
fn bump(db: &sql::Connection) -> Result<usize, Error> {
    let old = version(db)?;
    let new = old + 1;
    db.execute(format!("PRAGMA user_version = {new}"))?;

    // `new` is already a `usize`; the previous `new as usize` was a no-op cast.
    Ok(new)
}
/// Write-through cache update for objects of type `T`.
pub trait Update<T> {
    /// Value returned on a successful update (e.g. the previous entry).
    type Out;
    /// Error type raised by the backing store.
    type UpdateError: std::error::Error + Send + Sync + 'static;

    /// Insert or replace the cached `object` identified by `id` within
    /// repository `rid`.
    fn update(
        &mut self,
        rid: &RepoId,
        id: &ObjectId,
        object: &T,
    ) -> Result<Self::Out, Self::UpdateError>;
}

/// Cache eviction for objects of type `T`.
pub trait Remove<T> {
    /// Value returned on successful removal.
    type Out;
    /// Error type raised by the backing store.
    type RemoveError: std::error::Error + Send + Sync + 'static;

    /// Remove the object with the given `id` from the cache.
    fn remove(&mut self, id: &ObjectId) -> Result<Self::Out, Self::RemoveError>;
    /// Remove all cached objects belonging to repository `rid`.
    fn remove_all(&mut self, rid: &RepoId) -> Result<Self::Out, Self::RemoveError>;
}
/// A purely in-memory cache, mainly useful for testing.
#[derive(Clone, Debug)]
pub struct InMemory<T> {
    // Objects indexed first by repository, then by object id.
    inner: HashMap<RepoId, HashMap<ObjectId, T>>,
}

/// An empty cache. Implemented by hand so that `T: Default` isn't required.
impl<T> Default for InMemory<T> {
    fn default() -> Self {
        InMemory {
            inner: HashMap::default(),
        }
    }
}
impl<T> Update<T> for InMemory<T>
where
    T: Clone,
{
    /// The previously cached object under this id, if any.
    type Out = Option<T>;
    type UpdateError = Infallible;

    /// Clone `object` into the per-repository map, returning whatever
    /// entry it replaced.
    fn update(
        &mut self,
        rid: &RepoId,
        id: &ObjectId,
        object: &T,
    ) -> Result<Self::Out, Self::UpdateError> {
        let previous = self
            .inner
            .entry(*rid)
            .or_default()
            .insert(*id, object.clone());

        Ok(previous)
    }
}
/// A cache that caches nothing: every operation succeeds and does no work.
pub struct NoCache;

impl<T> Update<T> for NoCache {
    // Nothing to return; the update is discarded.
    type Out = ();
    type UpdateError = Infallible;

    /// Discard the update.
    fn update(
        &mut self,
        _rid: &RepoId,
        _id: &ObjectId,
        _object: &T,
    ) -> Result<Self::Out, Self::UpdateError> {
        Ok(())
    }
}

impl<T> Remove<T> for NoCache {
    // Nothing to return; there is nothing cached to remove.
    type Out = ();
    type RemoveError = Infallible;

    /// No-op; always succeeds.
    fn remove(&mut self, _id: &ObjectId) -> Result<Self::Out, Self::RemoveError> {
        Ok(())
    }

    /// No-op; always succeeds.
    fn remove_all(&mut self, _rid: &RepoId) -> Result<Self::Out, Self::RemoveError> {
        Ok(())
    }
}
/// A simple step counter for reporting progress as `current` out of `total`.
#[derive(Debug)]
pub struct Progress {
    current: usize,
    total: usize,
}

impl Progress {
    /// A fresh counter starting at zero completed steps.
    pub fn new(total: usize) -> Self {
        Self { current: 0, total }
    }

    /// A counter that is already complete (`current == total`).
    pub fn done(total: usize) -> Self {
        Self {
            current: total,
            total,
        }
    }

    /// Advance progress by one step.
    pub fn inc(&mut self) {
        self.current += 1;
    }

    /// Total number of steps.
    pub fn total(&self) -> usize {
        self.total
    }

    /// Number of steps completed so far.
    pub fn current(&self) -> usize {
        self.current
    }

    /// Completion as a percentage in `[0, 100]`; a zero-step task counts
    /// as fully done.
    pub fn percentage(&self) -> f32 {
        match self.total {
            0 => 100.,
            _ => (self.current as f32 / self.total as f32) * 100.0,
        }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::assert_matches;

    #[test]
    fn test_check_version() {
        let mut db = StoreWriter::memory().unwrap();
        // A fresh database is at version 0, hence out of date.
        assert_matches!(db.check_version(), Err(Error::OutOfDate));
        db.migrate(migrate::ignore).unwrap();
        assert_matches!(db.check_version(), Ok(()));
    }

    #[test]
    fn test_migrate_to() {
        let mut db = StoreWriter::memory().unwrap();
        assert_eq!(db.version().unwrap(), 0);

        // Migrating one step at a time bumps the version each time.
        for target in 1..=3 {
            assert_eq!(db.migrate_to(target, migrate::ignore).unwrap(), target);
            assert_eq!(db.version().unwrap(), target);
        }

        // A target at or below the current version is a no-op...
        assert_eq!(db.migrate_to(1, migrate::ignore).unwrap(), 3);
        assert_eq!(db.version().unwrap(), 3);

        // ...and so is a target beyond the known migrations.
        assert_eq!(db.migrate_to(99, migrate::ignore).unwrap(), 3);
        assert_eq!(db.version().unwrap(), 3);
    }
}