use std::sync::{Arc, Weak};
use rusqlite::hooks::Action;
use tokio::sync::Notify;
use tokio_rusqlite::Connection;
use super::{Db, StorageError};
/// Keeps the `tasks` update hook installed on a connection.
///
/// Returned by [`register_tasks_update_hook`]. Hold on to it for as long as
/// change notifications are wanted: dropping it (best-effort) clears the
/// connection's update hook again.
#[must_use = "dropping the guard immediately unregisters the update hook"]
pub struct UpdateHookGuard {
    // Handle to the connection the hook was installed on; `Drop` uses it to
    // clear the hook.
    conn: Connection,
}
impl std::fmt::Debug for UpdateHookGuard {
    /// Renders as `UpdateHookGuard { .. }`; the wrapped connection handle is
    /// deliberately left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("UpdateHookGuard");
        builder.finish_non_exhaustive()
    }
}
impl Drop for UpdateHookGuard {
    /// Best-effort removal of the connection's update hook.
    ///
    /// `Drop` cannot `.await`, so clearing the hook is handed off to a task on
    /// the current Tokio runtime. When no runtime is available (e.g. the guard
    /// is dropped outside Tokio), the hook is simply left installed. The hook
    /// installed by `register_tasks_update_hook` only holds a `Weak<Notify>`,
    /// so it becomes a no-op once the `Notify` itself is gone.
    ///
    /// SQLite keeps a single update hook per connection, so this clears
    /// whatever hook is currently installed. That includes a hook registered
    /// after this guard was created.
    fn drop(&mut self) {
        let Ok(handle) = tokio::runtime::Handle::try_current() else {
            return;
        };
        let conn = self.conn.clone();
        handle.spawn(async move {
            // Failure here (e.g. the connection has already been closed)
            // leaves nothing useful to do, so the error is intentionally ignored.
            let _ = conn
                .call(|c| {
                    c.update_hook(None::<fn(Action, &str, &str, i64)>);
                    Ok::<_, rusqlite::Error>(())
                })
                .await;
        });
    }
}
/// Installs a SQLite update hook on `db`'s connection that wakes `notify`
/// whenever a row in the `tasks` table is inserted or updated.
///
/// Deletes are not reported. The hook captures only a [`Weak`] reference to
/// `notify`, so callers must keep their own `Arc<Notify>` alive. Once every
/// `Arc` is dropped, the hook stops signalling.
///
/// `Notify::notify_one` stores a permit when nobody is waiting, so a change
/// that happens before the consumer calls `notified()` is not lost. Several
/// changes in a row may, however, collapse into a single wake-up.
///
/// SQLite keeps one update hook per connection, so this replaces any hook
/// previously installed on the same connection. It also observes only writes
/// made through that connection.
///
/// # Errors
///
/// Returns a [`StorageError`] if running the registration on the connection
/// fails (for example, if the connection has been closed).
pub async fn register_tasks_update_hook(
    db: &Db,
    notify: Arc<Notify>,
) -> Result<UpdateHookGuard, StorageError> {
    let weak: Weak<Notify> = Arc::downgrade(&notify);
    let conn = db.conn.clone();
    // `call` only borrows `conn`, so the same handle can be moved into the
    // guard afterwards; no second clone is needed.
    conn.call(move |c| {
        c.update_hook(Some(
            move |action: Action, _db: &str, table: &str, _rowid: i64| {
                if table == "tasks"
                    && matches!(action, Action::SQLITE_INSERT | Action::SQLITE_UPDATE)
                    && let Some(n) = weak.upgrade()
                {
                    n.notify_one();
                }
            },
        ));
        Ok::<_, rusqlite::Error>(())
    })
    .await?;
    Ok(UpdateHookGuard { conn })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::Db;
    use std::time::Duration;

    /// An INSERT into `tasks` issued through the very `Db` the hook was
    /// registered on must wake a waiter on the shared `Notify`.
    #[tokio::test]
    async fn insert_via_same_db_fires_hook() {
        const WAIT: Duration = Duration::from_millis(200);
        const INSERT_SQL: &str =
            "INSERT INTO tasks (id, kind, status, created_at, updated_at, params_json) \
             VALUES (?1, 'batch_fetch', 'pending', 0, 0, '{}')";

        let dir = tempfile::tempdir().unwrap();
        let db_file = dir.path().join("rover.db");
        let db = Db::open(&db_file).await.unwrap();

        let signal = Arc::new(Notify::new());
        let _guard = register_tasks_update_hook(&db, Arc::clone(&signal))
            .await
            .unwrap();

        let insert_result = db
            .conn
            .call(|conn| {
                conn.execute(INSERT_SQL, rusqlite::params!["hook-test-id"])?;
                Ok::<(), rusqlite::Error>(())
            })
            .await;
        insert_result.unwrap();

        let woke = tokio::time::timeout(WAIT, signal.notified()).await;
        woke.expect("notify did not fire within 200ms");
    }
}