use crate::errors::{Error, Result};
use anyhow::Context;
use rusqlite::{params, Connection, OptionalExtension, Transaction};
/// A database schema version, expressed as a `(major, minor)` pair.
///
/// The derived ordering compares the major component first and the minor component second;
/// `upgrade_db` uses that ordering to decide which migrations still need to run.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(super) struct DbVersion(pub u32, pub u32);
/// Signature shared by every schema migration step: it runs inside the given transaction.
type UpgradeFn = fn(&Transaction) -> Result<()>;

/// Every known schema version paired with the migration that produces it.
///
/// Entries must be listed in ascending version order: `upgrade_db` walks this list front to
/// back, and `LATEST_VERSION` is taken from the final entry.
const VERSIONS: &[(DbVersion, UpgradeFn)] = &[
    (DbVersion(0, 1), upgrade_to_0_1),
    (DbVersion(0, 2), upgrade_to_0_2),
];

/// The newest schema version this code can produce: the version of the last `VERSIONS` entry.
pub(super) const LATEST_VERSION: DbVersion = {
    let (newest, _) = VERSIONS[VERSIONS.len() - 1];
    newest
};
/// Bring the database schema up to `LATEST_VERSION`.
///
/// Each pending migration in `VERSIONS` runs in its own transaction and is committed before the
/// next one starts. If a later step fails, earlier steps therefore stay committed.
///
/// Only the major component is compared against `LATEST_VERSION`. A database with the same
/// major version but a newer minor version is accepted and left untouched (presumably because
/// minor bumps are meant to be backward-compatible — confirm).
///
/// # Errors
///
/// Returns `Error::Database` when the stored major version is newer than this code supports,
/// and propagates any error from reading the version or running a migration.
pub(super) fn upgrade_db(con: &mut Connection) -> Result<()> {
    let mut applied = get_db_version(con)?;
    if LATEST_VERSION.0 < applied.0 {
        return Err(Error::Database(
            "Database is too new for this version of TaskChampion".into(),
        ));
    }
    for &(target, step) in VERSIONS {
        // Skip migrations the database has already been through.
        if target <= applied {
            continue;
        }
        let t = con.transaction()?;
        step(&t)?;
        t.commit()?;
        applied = target;
    }
    Ok(())
}
/// Upgrade the schema to version 0.1: create the base tables, add the derived `operations`
/// columns and their indexes, and create and stamp the `version` table.
///
/// Every step is guarded (`CREATE TABLE IF NOT EXISTS` or a `has_column` check), so this can
/// run against a database that already contains some of these tables or columns. Presumably
/// such databases come from releases that predate schema versioning (TODO confirm).
///
/// NOTE(review): the JSON paths in the generated `operations.uuid` column below are written in
/// double quotes. SQLite parses double-quoted text as an identifier and accepts it as a string
/// literal only through a legacy compatibility quirk. `upgrade_to_0_2` drops and re-creates
/// this column with single-quoted literals. This historical step is kept as-is.
fn upgrade_to_0_1(t: &Transaction) -> Result<()> {
    let create_tables = vec![
        "CREATE TABLE IF NOT EXISTS operations (id INTEGER PRIMARY KEY AUTOINCREMENT, data STRING);",
        "CREATE TABLE IF NOT EXISTS sync_meta (key STRING PRIMARY KEY, value STRING);",
        "CREATE TABLE IF NOT EXISTS tasks (uuid STRING PRIMARY KEY, data STRING);",
        "CREATE TABLE IF NOT EXISTS working_set (id INTEGER PRIMARY KEY, uuid STRING);",
    ];
    for q in create_tables {
        t.execute(q, []).context("Creating table")?;
    }
    // Virtual generated column holding the `uuid` field from whichever of the `Update`,
    // `Create` or `Delete` keys is present in the JSON `data` (first non-null wins).
    // Presumably this is the uuid of the task the operation applies to.
    if !has_column(t, "operations", "uuid")? {
        t.execute(
            r#"ALTER TABLE operations ADD COLUMN uuid GENERATED ALWAYS AS (
coalesce(json_extract(data, "$.Update.uuid"),
json_extract(data, "$.Create.uuid"),
json_extract(data, "$.Delete.uuid"))) VIRTUAL"#,
            [],
        )
        .context("Adding operations.uuid")?;
        // The index is created only together with the column. If `uuid` already existed,
        // this step does not create `operations_by_uuid`.
        t.execute("CREATE INDEX operations_by_uuid ON operations (uuid)", [])
            .context("Creating operations_by_uuid")?;
    }
    // Flag column (presumably "has this operation been synced"). Both new and existing rows
    // get the DEFAULT of false.
    if !has_column(t, "operations", "synced")? {
        t.execute(
            "ALTER TABLE operations ADD COLUMN synced bool DEFAULT false",
            [],
        )
        .context("Adding operations.synced")?;
        t.execute(
            "CREATE INDEX operations_by_synced ON operations (synced)",
            [],
        )
        .context("Creating operations_by_synced")?;
    }
    create_version_table(t)?;
    set_db_version(t, DbVersion(0, 1))?;
    Ok(())
}
/// Upgrade the schema from 0.1 to 0.2 by rebuilding the generated `operations.uuid` column.
///
/// Version 0.1 wrote the column's JSON paths in double quotes, which SQLite parses as
/// identifiers and accepts as string literals only through a legacy quirk. This step replaces
/// the column with one that uses single-quoted string literals and re-creates its index.
///
/// `ALTER TABLE ... DROP COLUMN` requires SQLite 3.35.0 or later.
fn upgrade_to_0_2(t: &Transaction) -> Result<()> {
    // SQLite refuses to drop a column that an index refers to, so the index goes first.
    // `IF EXISTS` matters because `upgrade_to_0_1` skips creating this index when it finds a
    // pre-existing `uuid` column. Such a database may not have the index at all.
    t.execute(r#"DROP INDEX IF EXISTS operations_by_uuid"#, [])
        .context("Dropping index operations_by_uuid")?;
    t.execute(r#"ALTER TABLE operations DROP COLUMN uuid"#, [])
        .context("Removing incorrect operations.uuid")?;
    // The first non-null of the Update/Create/Delete uuid fields in the JSON `data`.
    t.execute(
        r#"ALTER TABLE operations ADD COLUMN uuid GENERATED ALWAYS AS (
coalesce(json_extract(data, '$.Update.uuid'),
json_extract(data, '$.Create.uuid'),
json_extract(data, '$.Delete.uuid'))) VIRTUAL"#,
        [],
    )
    .context("Creating correct operations.uuid")?;
    t.execute("CREATE INDEX operations_by_uuid ON operations (uuid)", [])
        .context("Creating index operations_by_uuid")?;
    set_db_version(t, DbVersion(0, 2))?;
    Ok(())
}
fn create_version_table(t: &Transaction) -> Result<()> {
t.execute(
r#"CREATE TABLE IF NOT EXISTS version (
singleton INTEGER PRIMARY KEY CHECK (singleton = 0),
major INTEGER,
minor INTEGER)"#,
[],
)
.context("Creating table")?;
Ok(())
}
pub(super) fn get_db_version(con: &mut Connection) -> Result<DbVersion> {
let version: Option<(u32, u32)> = match con
.query_row("SELECT major, minor FROM version", [], |r| {
Ok((r.get("major")?, r.get("minor")?))
})
.optional()
{
Ok(v) => v,
Err(err @ rusqlite::Error::SqliteFailure(_, _)) => {
if has_column(&con.transaction()?, "version", "major")? {
return Err(err.into());
}
None
}
Err(err) => return Err(err.into()),
};
let (major, minor) = version.unwrap_or((0, 0));
Ok(DbVersion(major, minor))
}
fn set_db_version(t: &Transaction, version: DbVersion) -> Result<()> {
let DbVersion(major, minor) = version;
t.execute(
r#"INSERT INTO version (singleton, major, minor) VALUES (0, ?, ?)
ON CONFLICT(singleton) do UPDATE SET major=?, minor=?"#,
params![major, minor, major, minor],
)?;
Ok(())
}
/// Report whether `table` has a column named `column`.
///
/// This uses `pragma_table_xinfo` rather than `pragma_table_info` because the latter omits
/// hidden and generated columns such as `operations.uuid`. A table that does not exist
/// yields no pragma rows and therefore `false`.
fn has_column(t: &Transaction, table: &str, column: &str) -> Result<bool> {
    let matching = t
        .query_row(
            "SELECT COUNT(*) AS c FROM pragma_table_xinfo(?) WHERE name=?",
            [table, column],
            |row| row.get::<_, u32>(0),
        )
        .with_context(|| format!("Checking for {table}.{column}"))?;
    Ok(matching != 0)
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn get_db_version_no_table() -> Result<()> {
        let mut con = Connection::open_in_memory()?;
        assert_eq!(get_db_version(&mut con)?, DbVersion(0, 0));
        Ok(())
    }

    #[test]
    fn get_db_version_empty() -> Result<()> {
        let mut con = Connection::open_in_memory()?;
        let t = con.transaction()?;
        create_version_table(&t)?;
        t.commit()?;
        assert_eq!(get_db_version(&mut con)?, DbVersion(0, 0));
        Ok(())
    }

    #[test]
    fn get_db_version_set() -> Result<()> {
        let mut con = Connection::open_in_memory()?;
        let t = con.transaction()?;
        create_version_table(&t)?;
        set_db_version(&t, DbVersion(3, 5))?;
        t.commit()?;
        assert_eq!(get_db_version(&mut con)?, DbVersion(3, 5));
        Ok(())
    }

    #[test]
    fn get_db_version_set_twice() -> Result<()> {
        let mut con = Connection::open_in_memory()?;
        let t = con.transaction()?;
        create_version_table(&t)?;
        set_db_version(&t, DbVersion(3, 5))?;
        set_db_version(&t, DbVersion(4, 7))?;
        t.commit()?;
        assert_eq!(get_db_version(&mut con)?, DbVersion(4, 7));
        Ok(())
    }

    #[test]
    fn test_upgrade_to_0_1() -> Result<()> {
        let mut con = Connection::open_in_memory()?;
        {
            let t = con.transaction()?;
            upgrade_to_0_1(&t)?;
            t.commit()?;
        }
        {
            let t = con.transaction()?;
            assert!(has_column(&t, "operations", "id")?);
            assert!(has_column(&t, "operations", "data")?);
            assert!(has_column(&t, "operations", "uuid")?);
            assert!(has_column(&t, "sync_meta", "key")?);
            assert!(has_column(&t, "sync_meta", "value")?);
            assert!(has_column(&t, "tasks", "uuid")?);
            assert!(has_column(&t, "tasks", "data")?);
            assert!(has_column(&t, "working_set", "id")?);
            assert!(has_column(&t, "working_set", "uuid")?);
        }
        assert_eq!(get_db_version(&mut con)?, DbVersion(0, 1));
        Ok(())
    }

    #[test]
    fn test_upgrade_to_0_2() -> Result<()> {
        let mut con = Connection::open_in_memory()?;
        {
            let t = con.transaction()?;
            upgrade_to_0_1(&t)?;
            upgrade_to_0_2(&t)?;
            t.commit()?;
        }
        {
            let t = con.transaction()?;
            assert!(has_column(&t, "operations", "id")?);
            assert!(has_column(&t, "operations", "data")?);
            assert!(has_column(&t, "operations", "uuid")?);
            assert!(has_column(&t, "sync_meta", "key")?);
            assert!(has_column(&t, "sync_meta", "value")?);
            assert!(has_column(&t, "tasks", "uuid")?);
            assert!(has_column(&t, "tasks", "data")?);
            assert!(has_column(&t, "working_set", "id")?);
            assert!(has_column(&t, "working_set", "uuid")?);
        }
        assert_eq!(get_db_version(&mut con)?, DbVersion(0, 2));
        Ok(())
    }

    /// A full upgrade of an empty database lands on `LATEST_VERSION`, and running it again
    /// is a no-op.
    #[test]
    fn upgrade_db_from_scratch() -> Result<()> {
        let mut con = Connection::open_in_memory()?;
        upgrade_db(&mut con)?;
        assert_eq!(get_db_version(&mut con)?, LATEST_VERSION);
        upgrade_db(&mut con)?;
        assert_eq!(get_db_version(&mut con)?, LATEST_VERSION);
        Ok(())
    }

    /// A database stamped with a newer major version than this code supports is rejected.
    #[test]
    fn upgrade_db_rejects_newer_major_version() -> Result<()> {
        let mut con = Connection::open_in_memory()?;
        let t = con.transaction()?;
        create_version_table(&t)?;
        set_db_version(&t, DbVersion(LATEST_VERSION.0 + 1, 0))?;
        t.commit()?;
        assert!(upgrade_db(&mut con).is_err());
        Ok(())
    }

    /// After a full upgrade, the generated `operations.uuid` column extracts the uuid from
    /// each of the Create/Update/Delete JSON shapes.
    #[test]
    fn operations_uuid_is_extracted() -> Result<()> {
        let mut con = Connection::open_in_memory()?;
        upgrade_db(&mut con)?;
        let uuid = "e2956511-fd47-4a4b-a1e5-3b4e2e7a3b4c";
        for kind in ["Create", "Update", "Delete"] {
            let data = format!(r#"{{"{kind}":{{"uuid":"{uuid}"}}}}"#);
            con.execute("INSERT INTO operations (data) VALUES (?)", [data])?;
        }
        let count: u32 = con.query_row(
            "SELECT COUNT(*) FROM operations WHERE uuid = ?",
            [uuid],
            |r| r.get(0),
        )?;
        assert_eq!(count, 3);
        Ok(())
    }
}