use std::{path::Path, sync::Arc, time::Duration};
use sqry_core::graph::CodeGraph;
use sqry_db::queries::{CalleesQuery, CallersQuery, RelationKey};
use tracing::warn;
/// Upper bound on the number of distinct symbol names (simple and qualified) whose
/// caller/callee queries `warm_persistent_queries` evaluates before the derived cache is saved.
const MAX_DERIVED_WARMUP_SYMBOLS: usize = 64;
/// Callback notified by sqryd when a code graph is published for a workspace.
///
/// `on_publish` is synchronous and returns `()`, so an implementation cannot report failure
/// to the caller. Implementations in this module that do real work hand it to a background
/// task (see [`spawn_hook`]) instead of blocking the caller.
pub trait SqrydHook: Send + Sync + std::fmt::Debug {
    /// Receives the workspace root and a shared handle to the newly published graph.
    fn on_publish(&self, workspace_root: &Path, graph: Arc<CodeGraph>);
}

/// Forwards to the wrapped hook, so `Arc<H>` (including `Arc<dyn SqrydHook>`) is itself a hook.
impl<T: SqrydHook + ?Sized> SqrydHook for Arc<T> {
    fn on_publish(&self, workspace_root: &Path, graph: Arc<CodeGraph>) {
        // Deref-coerce to the inner hook explicitly; calling `self.on_publish` here would
        // recurse into this very impl.
        let inner: &T = self;
        inner.on_publish(workspace_root, graph);
    }
}
/// Hook that ignores every publish notification.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoOpHook;

impl SqrydHook for NoOpHook {
    /// Intentionally does nothing; the graph handle is simply dropped.
    fn on_publish(&self, _workspace_root: &Path, _graph: Arc<CodeGraph>) {}
}

/// Shared, type-erased hook handle.
pub type SharedHook = Arc<dyn SqrydHook>;

/// Returns a [`SharedHook`] that wraps [`NoOpHook`].
#[must_use]
pub fn noop_hook() -> SharedHook {
    let hook = NoOpHook;
    // Unsized coercion `Arc<NoOpHook>` -> `Arc<dyn SqrydHook>` happens at the return.
    Arc::new(hook)
}
/// Runs a hook's future in the background, bounded by `timeout`, and logs (never propagates)
/// any failure.
///
/// * `timeout` — how long the spawned task waits for the future. On expiry the future is
///   dropped and a warning is logged. Dropping cancels pending async work, but cannot stop work
///   already handed to `tokio::task::spawn_blocking`.
/// * `workspace_root` — used only as log context.
/// * `task_label` — short static name included in log fields and messages.
/// * `fut_factory` — invoked inside the spawned task to build the future.
///
/// If the calling thread has no Tokio runtime, the hook is skipped with a warning instead of
/// panicking. This keeps the publish path running even when the caller is not inside a
/// runtime.
pub fn spawn_hook<F, Fut, E>(
    timeout: Duration,
    workspace_root: std::path::PathBuf,
    task_label: &'static str,
    fut_factory: F,
) where
    F: FnOnce() -> Fut + Send + 'static,
    Fut: std::future::Future<Output = Result<(), E>> + Send + 'static,
    E: std::fmt::Display + Send + 'static,
{
    // `tokio::spawn` panics outside a runtime context; probe first so hook failures stay
    // absorbed rather than unwinding into the synchronous `on_publish` caller.
    let Ok(runtime) = tokio::runtime::Handle::try_current() else {
        warn!(
            task = task_label,
            workspace = %workspace_root.display(),
            "sqryd hook {task_label} skipped: no Tokio runtime on this thread (absorbed; query path continues)",
        );
        return;
    };
    runtime.spawn(async move {
        let fut = fut_factory();
        match tokio::time::timeout(timeout, fut).await {
            Ok(Ok(())) => {}
            Ok(Err(err)) => {
                warn!(
                    task = task_label,
                    workspace = %workspace_root.display(),
                    error = %err,
                    "sqryd hook {task_label} failed (absorbed; query path continues)",
                );
            }
            Err(_elapsed) => {
                // Saturate rather than truncate if the timeout is absurdly large.
                let timeout_ms = u64::try_from(timeout.as_millis()).unwrap_or(u64::MAX);
                warn!(
                    task = task_label,
                    workspace = %workspace_root.display(),
                    timeout_ms,
                    "sqryd hook {task_label} timed out (absorbed; query path continues)",
                );
            }
        }
    });
}
/// Publish hook that persists the graph snapshot and a warmed derived-query cache under
/// `<workspace>/.sqry/graph/` on a background task.
///
/// The `timeout` bounds how long the background task waits for the save. The save itself runs
/// on a blocking thread, so it may still finish after the timeout warning has been logged.
#[derive(Debug, Clone)]
pub struct QueryDbHook {
    /// Deadline handed to `spawn_hook` for each background save.
    timeout: Duration,
    /// Configuration used to build the `sqry_db::QueryDb` whose derived cache is saved.
    query_db_config: sqry_db::QueryDbConfig,
}

impl QueryDbHook {
    /// Creates a hook with the given timeout and the default [`sqry_db::QueryDbConfig`].
    #[must_use]
    pub fn new(timeout: Duration) -> Arc<Self> {
        Self::with_query_db_config(timeout, sqry_db::QueryDbConfig::default())
    }

    /// Creates a hook with the given timeout and an explicit query-db configuration.
    #[must_use]
    pub fn with_query_db_config(
        timeout: Duration,
        query_db_config: sqry_db::QueryDbConfig,
    ) -> Arc<Self> {
        let hook = Self {
            query_db_config,
            timeout,
        };
        Arc::new(hook)
    }

    /// Returns the per-publish background timeout.
    #[must_use]
    pub fn timeout(&self) -> Duration {
        let Self { timeout, .. } = self;
        *timeout
    }
}
impl SqrydHook for QueryDbHook {
    /// Schedules a background snapshot + derived-cache save for `workspace_root`.
    ///
    /// Returns immediately. The work runs via [`spawn_hook`] under the
    /// `query-db-save-derived` label, and any error or timeout is logged rather than returned.
    fn on_publish(&self, workspace_root: &Path, graph: Arc<CodeGraph>) {
        let root = workspace_root.to_path_buf();
        let config = self.query_db_config.clone();
        // One copy of the root goes to `spawn_hook` for log context; the other moves into
        // the save task.
        let log_root = root.clone();
        spawn_hook::<_, _, anyhow::Error>(
            self.timeout,
            log_root,
            "query-db-save-derived",
            move || run_save_derived(root, graph, config),
        );
    }
}
/// Async wrapper that runs [`run_save_derived_blocking`] on Tokio's blocking thread pool.
///
/// # Errors
/// Returns the blocking function's error unchanged. A join failure (the blocking closure
/// panicked or was cancelled) is reported as an `anyhow` error naming the task.
async fn run_save_derived(
    workspace_root: std::path::PathBuf,
    graph: Arc<CodeGraph>,
    query_db_config: sqry_db::QueryDbConfig,
) -> anyhow::Result<()> {
    // File I/O and hashing are blocking; keep them off the async worker threads.
    let blocking = tokio::task::spawn_blocking(move || {
        run_save_derived_blocking(&workspace_root, &graph, query_db_config)
    });
    match blocking.await {
        Ok(outcome) => outcome,
        Err(join_err) => Err(anyhow::anyhow!(
            "spawn_blocking(query-db-save-derived) join: {join_err}"
        )),
    }
}
fn run_save_derived_blocking(
workspace_root: &Path,
graph: &CodeGraph,
query_db_config: sqry_db::QueryDbConfig,
) -> anyhow::Result<()> {
let graph_dir = workspace_root.join(".sqry").join("graph");
let snapshot_path = graph_dir.join("snapshot.sqry");
std::fs::create_dir_all(&graph_dir)?;
sqry_core::graph::unified::persistence::save_to_path(graph, &snapshot_path)?;
let sha = sqry_db::persistence::compute_file_sha256(&snapshot_path).map_err(|io_err| {
anyhow::anyhow!("compute_file_sha256({}): {io_err}", snapshot_path.display())
})?;
let persisted_graph =
sqry_core::graph::unified::persistence::load_from_path(&snapshot_path, None)?;
let snapshot_arc = Arc::new(persisted_graph.snapshot());
let db = sqry_db::QueryDb::new(snapshot_arc, query_db_config);
let warmed_entries = warm_persistent_queries(&db);
tracing::debug!(
workspace = %workspace_root.display(),
warmed_entries,
"QueryDbHook: warmed persistent query entries before derived-cache save"
);
let derived_path = sqry_db::derived_path(workspace_root, db.config());
sqry_db::persistence::save_derived(&db, sha, &derived_path, workspace_root)?;
Ok(())
}
/// Collects at most [`MAX_DERIVED_WARMUP_SYMBOLS`] distinct symbol names from `db`'s snapshot.
///
/// Nodes for which `is_unified_loser()` is true are skipped. For every other node, its simple
/// name and (if present) its qualified name are gathered, in node-iteration order, until the
/// cap is reached.
fn collect_warmup_symbols(db: &sqry_db::QueryDb) -> std::collections::BTreeSet<String> {
    let snapshot = db.snapshot();
    let strings = snapshot.strings();
    let mut symbol_names = std::collections::BTreeSet::new();
    'nodes: for (_node_id, node) in snapshot.iter_nodes() {
        if node.is_unified_loser() {
            continue;
        }
        let simple = strings.resolve(node.name);
        let qualified = node.qualified_name.and_then(|id| strings.resolve(id));
        for name in simple.into_iter().chain(qualified) {
            // Check before every insertion so the cap is a hard limit, even when one node
            // contributes both a simple and a qualified name.
            if symbol_names.len() >= MAX_DERIVED_WARMUP_SYMBOLS {
                break 'nodes;
            }
            symbol_names.insert(name.to_string());
        }
    }
    symbol_names
}

/// Evaluates `CallersQuery` and `CalleesQuery` for a bounded set of symbol names, so those
/// entries exist in `db` before its derived cache is saved.
///
/// Returns the number of query evaluations issued: two per collected symbol.
fn warm_persistent_queries(db: &sqry_db::QueryDb) -> usize {
    let mut warmed_entries = 0usize;
    for symbol_name in collect_warmup_symbols(db) {
        let key = RelationKey::exact(symbol_name);
        // Results are discarded; the calls are made only for the entries they leave in `db`.
        let _ = db.get::<CallersQuery>(&key);
        let _ = db.get::<CalleesQuery>(&key);
        warmed_entries += 2;
    }
    warmed_entries
}
/// Hook that records the workspace root of every `on_publish` call, in call order.
#[doc(hidden)]
#[derive(Debug, Default)]
pub struct RecordingHook {
    /// Workspace roots passed to `on_publish`, oldest first.
    pub invocations: parking_lot::Mutex<Vec<std::path::PathBuf>>,
}

impl RecordingHook {
    /// Creates an empty recorder behind an `Arc`.
    #[must_use]
    pub fn new() -> Arc<Self> {
        Arc::default()
    }

    /// Number of `on_publish` calls recorded so far.
    #[must_use]
    pub fn invocation_count(&self) -> usize {
        let recorded = self.invocations.lock();
        recorded.len()
    }

    /// Snapshot copy of the recorded workspace roots, in call order.
    #[must_use]
    pub fn invocation_roots(&self) -> Vec<std::path::PathBuf> {
        self.invocations.lock().to_vec()
    }
}

impl SqrydHook for RecordingHook {
    /// Appends `workspace_root` to the recorded invocations; the graph is ignored.
    fn on_publish(&self, workspace_root: &Path, _graph: Arc<CodeGraph>) {
        let root = workspace_root.to_path_buf();
        self.invocations.lock().push(root);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Polls every 20 ms until `path` exists or `within` elapses; returns whether it exists.
    ///
    /// Hooks do their work on background tasks, so tests wait on the observable effect rather
    /// than sleeping for a fixed interval, which is flaky on loaded CI machines.
    async fn wait_for_path(path: &Path, within: Duration) -> bool {
        let deadline = std::time::Instant::now() + within;
        loop {
            if path.exists() {
                return true;
            }
            if std::time::Instant::now() >= deadline {
                return false;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
    }

    #[test]
    fn noop_hook_compiles_through_shared_dispatch() {
        let hook: SharedHook = noop_hook();
        let graph = Arc::new(CodeGraph::new());
        hook.on_publish(Path::new("/repos/example"), graph);
    }

    #[test]
    fn recording_hook_captures_invocations_in_order() {
        let hook = RecordingHook::new();
        let graph = Arc::new(CodeGraph::new());
        hook.on_publish(Path::new("/repos/a"), Arc::clone(&graph));
        hook.on_publish(Path::new("/repos/b"), Arc::clone(&graph));
        assert_eq!(hook.invocation_count(), 2);
        let roots = hook.invocation_roots();
        assert_eq!(roots[0], Path::new("/repos/a"));
        assert_eq!(roots[1], Path::new("/repos/b"));
    }

    #[tokio::test]
    async fn spawn_hook_absorbs_error() {
        spawn_hook::<_, _, &'static str>(
            Duration::from_millis(100),
            std::path::PathBuf::from("/repos/example"),
            "test-hook",
            || async { Err("simulated failure") },
        );
        tokio::time::sleep(Duration::from_millis(50)).await;
    }

    #[tokio::test]
    async fn spawn_hook_absorbs_timeout() {
        spawn_hook::<_, _, &'static str>(
            Duration::from_millis(10),
            std::path::PathBuf::from("/repos/example"),
            "test-hook",
            || async {
                tokio::time::sleep(Duration::from_secs(1)).await;
                Ok(())
            },
        );
        tokio::time::sleep(Duration::from_millis(50)).await;
    }

    #[test]
    fn pf03a_query_db_hook_constructs_with_requested_timeout() {
        let hook = QueryDbHook::new(Duration::from_millis(1234));
        assert_eq!(hook.timeout(), Duration::from_millis(1234));
    }

    #[test]
    fn pf03a_query_db_hook_accepts_custom_config() {
        let cfg = sqry_db::QueryDbConfig::default();
        let hook = QueryDbHook::with_query_db_config(Duration::from_millis(50), cfg);
        assert_eq!(hook.timeout(), Duration::from_millis(50));
    }

    #[tokio::test]
    async fn pf09_query_db_hook_no_snapshot_file_writes_snapshot_and_derived() {
        let workspace = tempfile::tempdir().expect("tempdir");
        let hook = QueryDbHook::new(Duration::from_secs(2));
        let graph = Arc::new(CodeGraph::new());
        let graph_dir = workspace.path().join(".sqry").join("graph");
        std::fs::create_dir_all(&graph_dir).unwrap();
        hook.on_publish(workspace.path(), graph);
        let snapshot = graph_dir.join("snapshot.sqry");
        let derived = graph_dir.join("derived.sqry");
        // derived.sqry is written after snapshot.sqry, so waiting on it covers both files.
        wait_for_path(&derived, Duration::from_secs(5)).await;
        assert!(
            snapshot.exists(),
            "PF09: hook must write the published snapshot when it is absent (got {})",
            snapshot.display()
        );
        assert!(
            derived.exists(),
            "PF09: hook must create derived.sqry after writing snapshot.sqry (got {})",
            derived.display()
        );
    }

    #[tokio::test]
    async fn pf03a_query_db_hook_absorbs_save_failure() {
        let workspace = tempfile::tempdir().expect("tempdir");
        let hook = QueryDbHook::new(Duration::from_secs(1));
        let graph = Arc::new(CodeGraph::new());
        let snap_dir = workspace.path().join(".sqry").join("graph");
        std::fs::create_dir_all(&snap_dir).unwrap();
        // A directory squatting on the snapshot path makes the snapshot save fail.
        std::fs::create_dir_all(snap_dir.join("snapshot.sqry")).unwrap();
        hook.on_publish(workspace.path(), graph);
        tokio::time::sleep(Duration::from_millis(100)).await;
        let derived = snap_dir.join("derived.sqry");
        assert!(
            !derived.exists(),
            "PF03A: a failed snapshot save must not leave a partially-written derived.sqry"
        );
    }

    #[tokio::test]
    async fn pf03a_query_db_hook_writes_derived_sqry_when_snapshot_present() {
        let workspace = tempfile::tempdir().expect("tempdir");
        let snap_dir = workspace.path().join(".sqry").join("graph");
        std::fs::create_dir_all(&snap_dir).unwrap();
        let hook = QueryDbHook::new(Duration::from_secs(5));
        let graph = Arc::new(CodeGraph::new());
        hook.on_publish(workspace.path(), graph);
        let derived = snap_dir.join("derived.sqry");
        assert!(
            wait_for_path(&derived, Duration::from_secs(5)).await,
            "PF03A: hook must write derived.sqry to {} within 5s",
            derived.display()
        );
        let bytes = std::fs::read(&derived).unwrap();
        assert!(bytes.len() >= sqry_db::DERIVED_MAGIC.len());
        assert_eq!(
            &bytes[..sqry_db::DERIVED_MAGIC.len()],
            sqry_db::DERIVED_MAGIC,
            "derived.sqry must start with SQRY_DERIVED_V02 magic"
        );
    }
}