use libsql::{Builder, Connection, Database, params};
use serde::{Serialize, de::DeserializeOwned};
use std::future::Future;
use std::path::{Path, PathBuf};
use std::sync::{Arc, OnceLock};
use tokio::runtime::{Handle, Runtime};
/// String-backed, crate-internal error type for cache operations.
///
/// All underlying failures (libsql, bincode, I/O) are flattened into a
/// message; callers only ever log or display these errors.
#[derive(Debug)]
pub(crate) struct Error(String);

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl std::error::Error for Error {}
impl From<libsql::Error> for Error {
fn from(e: libsql::Error) -> Self {
Error(e.to_string())
}
}
impl From<bincode::Error> for Error {
fn from(e: bincode::Error) -> Self {
Error(e.to_string())
}
}
/// Content-addressed cache backed by a local libsql (SQLite) database.
///
/// `Clone` is cheap: all clones share one connection through `Arc<Inner>`.
#[derive(Clone)]
pub(crate) struct CaCache {
inner: Arc<Inner>,
// Soft size limit in bytes; consulted by `evict_if_over_limit`.
max_size_bytes: u64,
}
/// Connection state shared by every clone of `CaCache`.
struct Inner {
conn: Connection,
// Retained alongside `conn`; never read directly — presumably keeps the
// database handle alive for the connection's lifetime (TODO confirm).
#[allow(dead_code)]
db: Database,
// Private runtime, built only when the cache was opened outside any
// ambient tokio runtime (see `maybe_build_runtime`).
runtime: Option<Runtime>,
}
impl Inner {
/// Drive `fut` to completion from synchronous code.
///
/// Runtime selection (ambient handle, the cache's private runtime, or a
/// scoped worker thread) is delegated to `block_on_helper`.
fn block_on<F: Future + Send>(&self, fut: F) -> F::Output
where
F::Output: Send,
{
block_on_helper(&self.runtime, fut)
}
}
/// Run `fut` to completion from a synchronous context, picking a strategy
/// that is safe for the runtime we happen to be inside (or not):
///
/// - inside a multi-threaded tokio runtime: `block_in_place` + the ambient
///   handle (blocking the worker directly would deadlock the runtime);
/// - inside any other runtime flavor: a dedicated scoped worker thread
///   (`block_in_place` is not available on a current-thread runtime);
/// - outside a runtime: the provided private `runtime`, or a scoped worker
///   thread as a last resort.
fn block_on_helper<F: Future + Send>(runtime: &Option<Runtime>, fut: F) -> F::Output
where
    F::Output: Send,
{
    match Handle::try_current() {
        Ok(handle)
            if matches!(
                handle.runtime_flavor(),
                tokio::runtime::RuntimeFlavor::MultiThread
            ) =>
        {
            tokio::task::block_in_place(|| handle.block_on(fut))
        }
        Ok(_) => spawn_scoped_runtime(fut),
        Err(_) => match runtime {
            Some(rt) => rt.block_on(fut),
            None => spawn_scoped_runtime(fut),
        },
    }
}
/// Execute `fut` on a dedicated worker thread with its own single-threaded
/// tokio runtime, returning its output.
///
/// `std::thread::scope` is used (rather than `thread::spawn`) so `fut` may
/// borrow from the caller's stack without `'static` bounds. A panic on the
/// worker is propagated to the caller via the `expect` on `join`.
fn spawn_scoped_runtime<F: Future + Send>(fut: F) -> F::Output
where
    F::Output: Send,
{
    std::thread::scope(|scope| {
        let worker = scope.spawn(move || {
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .expect("failed to build tokio runtime worker thread");
            runtime.block_on(fut)
        });
        worker.join().expect("libsql worker thread panicked")
    })
}
fn maybe_build_runtime() -> Result<Option<Runtime>, Error> {
if Handle::try_current().is_ok() {
return Ok(None);
}
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map(Some)
.map_err(|e| Error(format!("tokio runtime: {e}")))
}
impl CaCache {
pub(crate) fn open(path: &Path, max_size_bytes: u64) -> Result<Self, Error> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent).map_err(|e| Error(format!("create_dir_all: {e}")))?;
}
let runtime = maybe_build_runtime()?;
let init = async {
let db = Builder::new_local(path).build().await?;
let conn = db.connect()?;
conn.execute_batch(
"PRAGMA journal_mode=WAL;
PRAGMA synchronous=NORMAL;
PRAGMA busy_timeout=5000;
CREATE TABLE IF NOT EXISTS ca_entries (
hash BLOB NOT NULL,
extr_ver TEXT NOT NULL,
grammar TEXT NOT NULL,
payload BLOB NOT NULL,
last_used INTEGER NOT NULL DEFAULT (strftime('%s','now')),
PRIMARY KEY (hash, extr_ver, grammar)
) WITHOUT ROWID;",
)
.await?;
Ok::<_, libsql::Error>((db, conn))
};
let (db, conn) = block_on_helper(&runtime, init)?;
Ok(Self {
inner: Arc::new(Inner { conn, db, runtime }),
max_size_bytes,
})
}
pub(crate) fn default_path() -> PathBuf {
dirs::config_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join("normalize")
.join("ca-cache.sqlite")
}
pub(crate) fn get<T: DeserializeOwned>(
&self,
hash: &[u8],
extr_ver: &str,
grammar: &str,
) -> Result<Option<T>, Error> {
let now = unix_now();
let conn = &self.inner.conn;
let bytes_opt: Option<Vec<u8>> = self.inner.block_on(async {
let mut rows = conn
.query(
"SELECT payload FROM ca_entries WHERE hash = ?1 AND extr_ver = ?2 AND grammar = ?3",
params![hash, extr_ver, grammar],
)
.await?;
let row = rows.next().await?;
if let Some(row) = row {
let bytes: Vec<u8> = row.get(0)?;
let _ = conn
.execute(
"UPDATE ca_entries SET last_used = ?1 WHERE hash = ?2 AND extr_ver = ?3 AND grammar = ?4",
params![now, hash, extr_ver, grammar],
)
.await;
Ok::<_, libsql::Error>(Some(bytes))
} else {
Ok(None)
}
})?;
if let Some(bytes) = bytes_opt {
let value: T = bincode::deserialize(&bytes)?;
return Ok(Some(value));
}
Ok(None)
}
pub(crate) fn put<T: Serialize>(
&self,
hash: &[u8],
extr_ver: &str,
grammar: &str,
value: &T,
) -> Result<(), Error> {
let bytes = bincode::serialize(value)?;
let now = unix_now();
let conn = &self.inner.conn;
self.inner.block_on(async {
conn.execute(
"INSERT OR REPLACE INTO ca_entries (hash, extr_ver, grammar, payload, last_used)
VALUES (?1, ?2, ?3, ?4, ?5)",
params![hash, extr_ver, grammar, bytes, now],
)
.await
})?;
Ok(())
}
pub(crate) fn gc_stale_versions(&self, current_extr_ver: &str) -> Result<usize, Error> {
let conn = &self.inner.conn;
let n = self.inner.block_on(async {
conn.execute(
"DELETE FROM ca_entries WHERE extr_ver != ?1 AND extr_ver NOT LIKE 'symbols-%'",
params![current_extr_ver],
)
.await
})?;
Ok(n as usize)
}
pub(crate) fn gc_stale_symbol_versions(
&self,
current_versions: &[&str],
) -> Result<usize, Error> {
let placeholders: String = current_versions
.iter()
.enumerate()
.map(|(i, _)| format!("?{}", i + 1))
.collect::<Vec<_>>()
.join(", ");
let sql = format!(
"DELETE FROM ca_entries WHERE extr_ver LIKE 'symbols-%' AND extr_ver NOT IN ({placeholders})"
);
let owned: Vec<String> = current_versions.iter().map(|s| s.to_string()).collect();
let conn = &self.inner.conn;
let n = self.inner.block_on(async {
let values: Vec<libsql::Value> = owned
.iter()
.map(|s| libsql::Value::Text(s.clone()))
.collect();
conn.execute(&sql, values).await
})?;
Ok(n as usize)
}
#[allow(dead_code)] pub(crate) fn evict_if_over_limit(&self) -> Result<(), Error> {
let conn = &self.inner.conn;
let max = self.max_size_bytes;
self.inner.block_on(async {
let size = current_db_size(conn).await.unwrap_or(0);
if size <= max {
return Ok::<_, libsql::Error>(());
}
let target = max * 9 / 10;
loop {
let current = current_db_size(conn).await.unwrap_or(0);
if current <= target {
break;
}
let deleted = conn
.execute(
"DELETE FROM ca_entries WHERE (hash, extr_ver, grammar) IN (
SELECT hash, extr_ver, grammar FROM ca_entries ORDER BY last_used ASC LIMIT 100
)",
(),
)
.await?;
if deleted == 0 {
break;
}
}
conn.execute_batch("VACUUM;").await?;
Ok(())
})?;
Ok(())
}
}
/// Total database file size in bytes (`page_count * page_size`) as
/// reported by the SQLite pragma table-valued functions.
///
/// A missing or unreadable result row counts as zero.
async fn current_db_size(conn: &Connection) -> Result<u64, libsql::Error> {
    let mut rows = conn
        .query(
            "SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()",
            (),
        )
        .await?;
    let bytes: i64 = match rows.next().await? {
        Some(row) => row.get(0).unwrap_or(0),
        None => 0,
    };
    Ok(bytes as u64)
}
/// Current time as whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock is set before 1970 (pre-epoch).
fn unix_now() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}
// Process-wide, lazily-initialized symbol cache. The `Option` is what gets
// cached: a failed open stores `None`, so the open is attempted only once.
static SYMBOL_CACHE: OnceLock<Option<CaCache>> = OnceLock::new();
// Extractor-version tags that `gc_stale_symbol_versions` treats as current.
pub(crate) const SYMBOL_CACHE_VERSIONS: &[&str] = &["symbols-v1-all", "symbols-v1-public"];
/// Return the process-wide symbol cache, opening it (and GC'ing stale
/// symbol versions) on first use.
///
/// Returns `None` when the cache cannot be opened; that outcome is cached
/// as well, so the failure is logged only once per process.
pub(crate) fn symbol_cache() -> Option<&'static CaCache> {
    let slot = SYMBOL_CACHE.get_or_init(|| {
        let path = CaCache::default_path();
        let cache = match CaCache::open(&path, 512 * 1024 * 1024) {
            Ok(cache) => cache,
            Err(e) => {
                tracing::debug!(
                    "normalize-facts: symbol cache unavailable at {}: {}",
                    path.display(),
                    e
                );
                return None;
            }
        };
        // GC failures are non-fatal: a usable cache beats no cache.
        if let Err(e) = cache.gc_stale_symbol_versions(SYMBOL_CACHE_VERSIONS) {
            tracing::debug!("normalize-facts: symbol cache GC error: {}", e);
        }
        Some(cache)
    });
    slot.as_ref()
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use tempfile::NamedTempFile;

    /// Round-trippable value stored in the cache during tests.
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Payload {
        symbols: Vec<String>,
        count: u32,
    }

    /// Open a cache at a fresh temp path with a generous 1 GiB limit.
    fn temp_cache() -> CaCache {
        let tmp = NamedTempFile::new().unwrap();
        let db_path = tmp.path().to_path_buf();
        // Leak the tempfile guard so the file is not deleted out from
        // under the open database.
        std::mem::forget(tmp);
        CaCache::open(&db_path, 1024 * 1024 * 1024).unwrap()
    }

    #[test]
    fn round_trip() {
        let cache = temp_cache();
        let key = blake3::hash(b"hello world");
        let stored = Payload {
            symbols: vec!["foo".into(), "bar".into()],
            count: 42,
        };
        cache.put(key.as_bytes(), "v1", "rust", &stored).unwrap();
        let loaded: Option<Payload> = cache.get(key.as_bytes(), "v1", "rust").unwrap();
        assert_eq!(loaded, Some(stored));
    }

    #[test]
    fn version_mismatch_returns_none() {
        let cache = temp_cache();
        let key = blake3::hash(b"hello world");
        let stored = Payload {
            symbols: vec![],
            count: 1,
        };
        cache.put(key.as_bytes(), "v1", "rust", &stored).unwrap();
        // Same hash and grammar, different extractor version: must miss.
        let loaded: Option<Payload> = cache.get(key.as_bytes(), "v2", "rust").unwrap();
        assert_eq!(loaded, None);
    }

    #[test]
    fn gc_stale_versions() {
        let cache = temp_cache();
        let key = blake3::hash(b"test");
        let stored = Payload {
            symbols: vec![],
            count: 0,
        };
        cache.put(key.as_bytes(), "old", "rust", &stored).unwrap();
        cache
            .put(key.as_bytes(), "current", "rust", &stored)
            .unwrap();
        assert_eq!(cache.gc_stale_versions("current").unwrap(), 1);
        let stale: Option<Payload> = cache.get(key.as_bytes(), "old", "rust").unwrap();
        assert_eq!(stale, None);
        let kept: Option<Payload> = cache.get(key.as_bytes(), "current", "rust").unwrap();
        assert!(kept.is_some());
    }

    #[test]
    fn gc_stale_versions_preserves_symbol_cache() {
        let cache = temp_cache();
        let key = blake3::hash(b"test");
        let stored = Payload {
            symbols: vec![],
            count: 0,
        };
        cache.put(key.as_bytes(), "old", "rust", &stored).unwrap();
        cache
            .put(key.as_bytes(), "symbols-v1-all", "rust", &stored)
            .unwrap();
        assert_eq!(cache.gc_stale_versions("current").unwrap(), 1);
        let stale: Option<Payload> = cache.get(key.as_bytes(), "old", "rust").unwrap();
        assert_eq!(stale, None);
        let symbols: Option<Payload> = cache
            .get(key.as_bytes(), "symbols-v1-all", "rust")
            .unwrap();
        assert!(
            symbols.is_some(),
            "symbol cache entries must survive extraction GC"
        );
    }

    #[test]
    fn gc_stale_symbol_versions() {
        let cache = temp_cache();
        let key = blake3::hash(b"sym-test");
        let stored = Payload {
            symbols: vec![],
            count: 0,
        };
        for ver in ["symbols-v0-all", "symbols-v1-all", "symbols-v1-public"] {
            cache.put(key.as_bytes(), ver, "rust", &stored).unwrap();
        }
        let deleted = cache
            .gc_stale_symbol_versions(&["symbols-v1-all", "symbols-v1-public"])
            .unwrap();
        assert_eq!(deleted, 1);
        let stale: Option<Payload> = cache
            .get(key.as_bytes(), "symbols-v0-all", "rust")
            .unwrap();
        assert_eq!(stale, None);
        let kept: Option<Payload> = cache
            .get(key.as_bytes(), "symbols-v1-all", "rust")
            .unwrap();
        assert!(kept.is_some());
    }

    #[test]
    fn eviction_under_limit() {
        let cache = temp_cache();
        // Stay well under the 1 GiB limit: eviction must be a no-op and
        // must not error.
        for i in 0u32..10 {
            let key = blake3::hash(i.to_le_bytes().as_slice());
            let stored = Payload {
                symbols: vec!["x".repeat(1000)],
                count: i,
            };
            cache.put(key.as_bytes(), "v1", "rust", &stored).unwrap();
        }
        cache.evict_if_over_limit().unwrap();
    }
}