use rusqlite::types::Value;
use rusqlite::{params, OptionalExtension};
use tokio_rusqlite::Connection;
use tracing::{debug, instrument, trace, Level};
use crate::error::Result;
use crate::proto::codec::Change;
pub const CRSQL_TRACKED_TAG_WHOLE_DATABASE: i32 = 0;
pub const CRSQL_TRACKED_EVENT_RECEIVE: i32 = 0;
#[derive(Clone)]
pub struct DB {
conn: Connection,
}
impl DB {
#[instrument(level = Level::DEBUG, err)]
pub async fn new(db_path: Option<&str>, ext_path: &str) -> Result<DB> {
let conn = match db_path {
Some(path) => Connection::open(path).await,
None => Connection::open_in_memory().await,
}?;
let load_ext_path = ext_path.to_owned();
conn.call(move |c| {
unsafe {
c.load_extension_enable()?;
let r = c.load_extension(load_ext_path, Some("sqlite3_crsqlite_init"))?;
c.load_extension_disable()?;
r
};
Ok(())
})
.await?;
Ok(DB { conn })
}
#[instrument(skip(self), level = Level::DEBUG, ret, err)]
pub async fn tracked_peer_version(&self, site_id: Vec<u8>) -> Result<i64> {
let version = match self
.conn
.call(move |conn| {
conn.query_row(
"
select max(coalesce(version, 0)) from crsql_tracked_peers
where site_id = ? and event = ?",
params![site_id, CRSQL_TRACKED_EVENT_RECEIVE],
|row| row.get::<usize, Option<i64>>(0),
)
.optional()
})
.await?
{
Some(Some(tracked_version)) => tracked_version,
_ => 0,
};
Ok(version)
}
#[instrument(skip(self, changes), level = Level::DEBUG, ret, err)]
pub async fn merge(&self, site_id: Vec<u8>, changes: Vec<Change>) -> Result<i64> {
let result = self
.conn
.call(move |conn| {
let tx = conn.transaction()?;
let mut ins_changes = tx.prepare(
"
insert into crsql_changes (
\"table\", pk, cid, val, col_version, db_version, site_id, cl, seq)
values (?, ?, ?, ?, ?, ?, ?, ?, ?);",
)?;
let mut max_db_version = 0;
for change in changes {
if change.db_version > max_db_version {
max_db_version = change.db_version;
}
ins_changes.execute(params![
change.table,
change.pk,
change.cid,
change.val,
change.col_version,
change.db_version,
&site_id,
change.cl,
change.seq,
])?;
trace!("merge change {:?} {:?}", change, site_id);
}
trace!(site_id = format!("{:?}", site_id), max_db_version);
tx.execute(
"
insert into crsql_tracked_peers (site_id, version, tag, event)
values (?, ?, ?, ?)
on conflict do update set version = max(version, excluded.version)",
params![
site_id,
max_db_version,
CRSQL_TRACKED_TAG_WHOLE_DATABASE,
CRSQL_TRACKED_EVENT_RECEIVE
],
)?;
drop(ins_changes);
tx.commit()?;
Ok(max_db_version)
})
.await?;
Ok(result)
}
#[instrument(skip(self), level = Level::DEBUG, err)]
pub async fn close(self) -> Result<()> {
self.conn
.call(|c| {
c.query_row("SELECT crsql_finalize()", [], |_row| Ok(()))?;
Ok(())
})
.await?;
debug!("crsql_finalized");
Ok(())
}
#[instrument(skip(self), level = Level::DEBUG, ret, err)]
pub async fn status(&self) -> Result<(Vec<u8>, i64)> {
let (site_id, db_version) = self.conn
.call(|c| {
c.query_row(
"
select crsql_site_id(), coalesce((select max(db_version) from crsql_changes where site_id is null), 0);",
[],
|row| Ok((row.get::<usize, Vec<u8>>(0)?, row.get::<usize, i64>(1)?)),
)
})
.await?;
Ok((site_id, db_version))
}
#[instrument(skip(self), level = Level::DEBUG)]
pub async fn changes(&self, since_db_version: i64) -> Result<(Vec<u8>, Vec<Change>)> {
let (site_id, db_version) = self.status().await?;
if since_db_version >= db_version {
return Ok((site_id, vec![]));
}
let changes = self
.conn
.call(move |c| {
let mut stmt = c.prepare(
"
select
\"table\",
pk,
cid,
val,
col_version,
db_version,
cl,
seq
from crsql_changes
where db_version > ?
and site_id is null",
)?;
let mut result = vec![];
let mut rows = stmt.query([since_db_version])?;
while let Some(row) = rows.next()? {
let change = Change {
table: row.get::<usize, String>(0)?,
pk: row.get::<usize, Vec<u8>>(1)?,
cid: row.get::<usize, String>(2)?,
val: row.get::<usize, Value>(3)?,
col_version: row.get::<usize, i64>(4)?,
db_version: row.get::<usize, i64>(5)?,
cl: row.get::<usize, i64>(6)?,
seq: row.get::<usize, i64>(7)?,
};
result.push(change);
}
Ok(result)
})
.await?;
Ok((site_id, changes))
}
}
#[cfg(test)]
mod tests {
use super::*;
async fn new_db() -> DB {
DB::new(None, "target/debug/crsqlite")
.await
.expect("new db")
}
#[tokio::test]
async fn version_advances() {
let test_points: Vec<(i32, i32, String)> = vec![
(1, 2, "👾".to_string()),
(2, 3, "🤖".to_string()),
(3, 5, "😀".to_string()),
(5, 8, "🪣".to_string()),
];
let db = new_db().await;
let (site_id, version) = db.status().await.expect("status");
assert_eq!(site_id.len(), 16, "site_id is expected length");
assert_ne!(site_id.as_slice(), [0; 16], "site_id is not all zeroes");
assert_eq!(version, 0i64, "initial version");
db.conn
.call(move |conn| {
conn.execute(
"create table canvas (x integer, y integer, value text, primary key (x, y));",
[],
)?;
conn.query_row("select crsql_as_crr('canvas');", [], |_| Ok(()))?;
for i in 0..test_points.len() {
conn.execute(
"insert into canvas (x, y, value) values (?, ?, ?)",
params![test_points[i].0, test_points[i].1, test_points[i].2],
)?;
}
Ok(())
})
.await
.expect("create table with rows");
let (site_id_2, version_2) = db.status().await.expect("status");
assert_eq!(site_id, site_id_2);
assert_eq!(version_2, 4i64);
db.close().await.expect("close db");
}
#[tokio::test]
async fn merge_changes() {
let db_1 = new_db().await;
let db_2 = new_db().await;
let (site_id_1, version_1) = db_1.status().await.expect("status");
let (site_id_2, version_2) = db_2.status().await.expect("status");
assert_ne!(site_id_1, site_id_2);
db_1.conn
.call(move |conn| {
conn.execute(
"create table canvas (x integer, y integer, value text, primary key (x, y));",
[],
)?;
conn.query_row("select crsql_as_crr('canvas');", [], |_| Ok(()))?;
conn.execute(
"insert into canvas (x, y, value) values (?, ?, ?)",
params![1, 1, "🌟"],
)?;
conn.execute(
"insert into canvas (x, y, value) values (?, ?, ?)",
params![2, 2, "🚀"],
)?;
Ok(())
})
.await
.expect("make changes in db1");
db_2.conn
.call(move |conn| {
conn.execute(
"create table canvas (x integer, y integer, value text, primary key (x, y));",
[],
)?;
conn.query_row("select crsql_as_crr('canvas');", [], |_| Ok(()))?;
Ok(())
})
.await
.expect("create table in db_2");
let changes_1 = db_1
.changes(version_1)
.await
.expect("get changes from db_1")
.1;
let merged_version = db_2
.merge(site_id_1, changes_1)
.await
.expect("merge changes into db_2");
assert_eq!(db_1.status().await.expect("db_1 status").1, merged_version);
assert!(merged_version > version_2);
let result = db_2
.conn
.call(move |conn| {
let mut stmt = conn.prepare("select x, y, value from canvas order by x, y")?;
let mut result = vec![];
let mut rows = stmt.query([])?;
while let Some(row) = rows.next()? {
result.push((
row.get::<usize, i32>(0)?,
row.get::<usize, i32>(1)?,
row.get::<usize, String>(2)?,
));
}
Ok(result)
})
.await
.expect("contents of db_2");
assert_eq!(
result,
vec![(1, 1, "🌟".to_string()), (2, 2, "🚀".to_string()),]
);
db_1.close().await.expect("close db_1");
db_2.close().await.expect("close db_2");
}
}