use anyhow::{Context, Result};
use parking_lot::Mutex as PMutex;
use rusqlite::Connection;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
#[cfg(test)]
use std::sync::Mutex;
use std::sync::{Arc, OnceLock};
use std::time::{Duration, Instant};
use super::migrations::{migrate_legacy_embeddings_to_sidecar, purge_global_topic_trees};
use super::recovery::{is_io_open_error, try_cleanup_stale_files};
use super::schema::SCHEMA;
use super::{db_path_for, SQLITE_BUSY_TIMEOUT};
use crate::memory::config::MemoryConfig;
/// Test-only tally of how many times the chunk DB schema was applied, keyed by DB path.
#[cfg(test)]
static SCHEMA_APPLY_COUNTS: OnceLock<Mutex<HashMap<PathBuf, usize>>> = OnceLock::new();

/// Notes that the schema was just applied to the database at `_path`.
///
/// Compiles to a no-op outside `cfg(test)`. In test builds it bumps the per-path counter in
/// `SCHEMA_APPLY_COUNTS`. The parameter keeps its leading underscore so non-test builds don't
/// warn about it being unused.
fn record_schema_apply(_path: &Path) {
    #[cfg(test)]
    {
        let mut tally = SCHEMA_APPLY_COUNTS
            .get_or_init(Default::default)
            .lock()
            .expect("schema apply count mutex poisoned");
        *tally.entry(_path.to_path_buf()).or_default() += 1;
    }
}

/// Returns how many times `record_schema_apply` has run for `path` in this test process.
///
/// Yields 0 when nothing has been recorded yet or when the tally mutex is poisoned.
#[cfg(test)]
#[doc(hidden)]
pub(crate) fn schema_apply_count_for_path_for_tests(path: &Path) -> usize {
    let Some(tally) = SCHEMA_APPLY_COUNTS.get() else {
        return 0;
    };
    match tally.lock() {
        Ok(guard) => guard.get(path).copied().unwrap_or(0),
        Err(_) => 0,
    }
}
/// Number of consecutive connection-init failures that trips a DB path's circuit breaker.
pub(crate) const CB_THRESHOLD: u32 = 3;
/// How long a tripped breaker rejects init attempts, measured from its latest recorded failure.
pub(crate) const CB_COOLDOWN: Duration = Duration::from_millis(30_000);
/// Per-database-path circuit breaker guarding connection initialization.
///
/// Counts consecutive init failures. Once `CB_THRESHOLD` is reached the breaker trips, and
/// `is_open` reports `true` until `CB_COOLDOWN` has elapsed since the latest recorded failure.
/// A single success resets all state.
///
/// All atomics use `Relaxed` ordering and are read independently. A reader may therefore
/// briefly see `tripped` set before `last_trip` is filled in; `is_open` treats a missing
/// timestamp as "not open".
struct CircuitBreaker {
    /// Failures recorded since the last success.
    consecutive_failures: AtomicU32,
    /// Set when the threshold is reached; cleared only by `record_success`.
    tripped: AtomicBool,
    /// When the breaker tripped, refreshed on every failure while tripped; `None` after a success.
    last_trip: PMutex<Option<Instant>>,
}

impl CircuitBreaker {
    /// Creates a closed breaker with no recorded failures.
    fn new() -> Self {
        let consecutive_failures = AtomicU32::new(0);
        let tripped = AtomicBool::new(false);
        let last_trip = PMutex::new(None);
        Self {
            consecutive_failures,
            tripped,
            last_trip,
        }
    }

    /// Resets the breaker after a successful init.
    ///
    /// Returns `true` if the breaker had been tripped before this call.
    fn record_success(&self) -> bool {
        self.consecutive_failures.store(0, Ordering::Relaxed);
        *self.last_trip.lock() = None;
        self.tripped.swap(false, Ordering::Relaxed)
    }

    /// Registers one failed init attempt.
    ///
    /// Returns `true` only for the call that moves the breaker into the tripped state. That
    /// call, and every later failure recorded while tripped, restarts the cooldown clock.
    fn record_failure(&self) -> bool {
        let failures = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;
        // The swap is only attempted once the threshold is reached; it doubles as the
        // "am I the caller that tripped it?" test.
        let newly_tripped =
            failures >= CB_THRESHOLD && !self.tripped.swap(true, Ordering::Relaxed);
        if newly_tripped || self.tripped.load(Ordering::Relaxed) {
            *self.last_trip.lock() = Some(Instant::now());
        }
        newly_tripped
    }

    /// Returns `true` while the breaker is tripped and still inside its cooldown window.
    ///
    /// Once the cooldown elapses this returns `false` again, letting init attempts through,
    /// even though `tripped` stays set until a success is recorded.
    fn is_open(&self) -> bool {
        if self.tripped.load(Ordering::Relaxed) {
            let tripped_at = *self.last_trip.lock();
            tripped_at.is_some_and(|at| at.elapsed() < CB_COOLDOWN)
        } else {
            false
        }
    }
}
/// Process-wide registry of chunk DB state, keyed by database path.
struct ConnectionCache {
    /// Initialized connections. Each sits behind its own mutex, so users of one DB are serialized.
    connections: PMutex<HashMap<PathBuf, Arc<PMutex<Connection>>>>,
    /// Circuit breakers tracking init failures per DB path.
    breakers: PMutex<HashMap<PathBuf, Arc<CircuitBreaker>>>,
    /// Per-path locks that let only one thread at a time open/initialize a given DB.
    init_locks: PMutex<HashMap<PathBuf, Arc<PMutex<()>>>>,
}

static CONN_CACHE: OnceLock<ConnectionCache> = OnceLock::new();

/// Returns the global `ConnectionCache`, creating it empty on first use.
fn conn_cache() -> &'static ConnectionCache {
    fn empty_map<V>() -> PMutex<HashMap<PathBuf, V>> {
        PMutex::new(HashMap::new())
    }
    CONN_CACHE.get_or_init(|| ConnectionCache {
        connections: empty_map(),
        breakers: empty_map(),
        init_locks: empty_map(),
    })
}
/// Configures a freshly opened chunk DB connection and brings its schema up to date.
///
/// Steps, in order:
/// 1. Set the busy timeout and the `foreign_keys`/`synchronous` pragmas.
/// 2. Apply the schema (`apply_schema`).
/// 3. Run the data migrations imported from `super::migrations`.
///
/// # Errors
/// Returns the first failing step, with context describing it.
fn init_db(conn: &Connection, config: &MemoryConfig) -> Result<()> {
    conn.busy_timeout(SQLITE_BUSY_TIMEOUT)
        .context("Failed to configure chunk DB busy timeout")?;

    // Each pragma runs on its own so a failure carries its own context message.
    let pragmas = [
        (
            "PRAGMA foreign_keys = ON;",
            "Failed to enable chunk DB foreign_keys pragma",
        ),
        (
            "PRAGMA synchronous = FULL;",
            "Failed to set chunk DB synchronous=FULL",
        ),
    ];
    for (sql, failure) in pragmas {
        conn.execute_batch(sql).context(failure)?;
    }

    apply_schema(conn)?;
    migrate_legacy_embeddings_to_sidecar(conn, config)?;
    purge_global_topic_trees(conn, config)?;
    Ok(())
}
/// Applies the chunk DB schema and brings older databases up to date.
///
/// Steps, in order:
/// 1. Set `journal_mode=TRUNCATE`.
/// 2. Execute the base `SCHEMA` batch.
/// 3. Add the columns and indexes introduced after the base schema.
///
/// The column/index steps are idempotent: `add_column_if_missing` tolerates existing
/// columns, and the indexes use `IF NOT EXISTS`. This runs on every fresh connection, so
/// `SCHEMA` is presumably idempotent as well (TODO confirm in `super::schema`).
///
/// # Errors
/// Returns the first failing step, with context naming it.
fn apply_schema(conn: &Connection) -> Result<()> {
    // `PRAGMA journal_mode=...` reports the resulting mode as a row, so it goes through
    // `query_row` (rusqlite's `execute` rejects statements that return rows).
    // The reported mode is deliberately not enforced. SQLite can keep a different mode
    // (e.g. for in-memory databases), and initialization proceeds either way.
    let _journal_mode: String = conn
        .query_row("PRAGMA journal_mode=TRUNCATE", [], |row| row.get(0))
        .context("Failed to set chunk DB journal_mode=TRUNCATE")?;

    conn.execute_batch(SCHEMA)
        .context("Failed to initialize chunk DB schema")?;

    // Columns added after the original schema shipped; older DBs pick them up here.
    add_column_if_missing(conn, "mem_tree_chunks", "embedding", "BLOB")?;
    add_column_if_missing(conn, "mem_tree_score", "llm_importance", "REAL")?;
    add_column_if_missing(conn, "mem_tree_score", "llm_importance_reason", "TEXT")?;
    add_column_if_missing(conn, "mem_tree_chunks", "parent_summary_id", "TEXT")?;
    add_column_if_missing(conn, "mem_tree_summaries", "embedding", "BLOB")?;
    add_column_if_missing(
        conn,
        "mem_tree_chunks",
        "lifecycle_status",
        "TEXT NOT NULL DEFAULT 'admitted'",
    )?;
    // The index must come after its column exists.
    conn.execute_batch(
        "CREATE INDEX IF NOT EXISTS idx_mem_tree_chunks_lifecycle \
         ON mem_tree_chunks(lifecycle_status);",
    )
    .context("Failed to create mem_tree_chunks lifecycle index")?;

    add_column_if_missing(conn, "mem_tree_chunks", "path_scope", "TEXT")?;
    add_column_if_missing(conn, "mem_tree_chunks", "content_path", "TEXT")?;
    add_column_if_missing(conn, "mem_tree_chunks", "content_sha256", "TEXT")?;
    add_column_if_missing(conn, "mem_tree_summaries", "content_path", "TEXT")?;
    add_column_if_missing(conn, "mem_tree_summaries", "content_sha256", "TEXT")?;
    add_column_if_missing(conn, "mem_tree_summaries", "doc_id", "TEXT")?;
    add_column_if_missing(conn, "mem_tree_summaries", "version_ms", "INTEGER")?;
    // Same ordering constraint: doc_id / version_ms must exist before this index.
    conn.execute_batch(
        "CREATE INDEX IF NOT EXISTS idx_mem_tree_summaries_doc_version \
         ON mem_tree_summaries(tree_id, doc_id, version_ms);",
    )
    .context("Failed to create mem_tree_summaries doc/version index")?;

    add_column_if_missing(conn, "mem_tree_chunks", "raw_refs_json", "TEXT")?;
    add_column_if_missing(
        conn,
        "mem_tree_entity_index",
        "is_user",
        "INTEGER NOT NULL DEFAULT 0",
    )?;
    add_column_if_missing(conn, "mem_tree_jobs", "failure_reason", "TEXT")?;
    add_column_if_missing(conn, "mem_tree_jobs", "failure_class", "TEXT")?;
    Ok(())
}
/// Adds column `name`, declared as `sql_type`, to `table`.
///
/// The error SQLite raises when the column already exists is ignored, which makes this safe
/// to call on every startup.
///
/// `table`, `name` and `sql_type` are spliced into the DDL text without escaping, so they
/// must be trusted identifiers (the callers in this file pass string literals).
///
/// # Errors
/// Any failure other than SQLite's "duplicate column name" error, with context naming
/// `table.name`.
pub(super) fn add_column_if_missing(
    conn: &Connection,
    table: &str,
    name: &str,
    sql_type: &str,
) -> Result<()> {
    let ddl = format!("ALTER TABLE {table} ADD COLUMN {name} {sql_type}");
    let Err(err) = conn.execute(&ddl, []) else {
        return Ok(());
    };
    if err.to_string().contains("duplicate column name") {
        // The column is already present (an earlier run added it): nothing to do.
        return Ok(());
    }
    Err(err).with_context(|| format!("Failed to add column {table}.{name}"))
}
pub(crate) fn get_or_init_connection(config: &MemoryConfig) -> Result<Arc<PMutex<Connection>>> {
let db_path = db_path_for(config);
{
let breakers = conn_cache().breakers.lock();
if let Some(breaker) = breakers.get(&db_path) {
if breaker.is_open() {
anyhow::bail!(
"[chunks] circuit breaker open for {}: too many consecutive init failures",
db_path.display()
);
}
}
}
{
let guard = conn_cache().connections.lock();
if let Some(conn) = guard.get(&db_path) {
return Ok(Arc::clone(conn));
}
}
let init_lock = {
let mut guard = conn_cache().init_locks.lock();
guard
.entry(db_path.clone())
.or_insert_with(|| Arc::new(PMutex::new(())))
.clone()
};
let _init_guard = init_lock.lock();
{
let guard = conn_cache().connections.lock();
if let Some(conn) = guard.get(&db_path) {
return Ok(Arc::clone(conn));
}
}
let conn = open_and_init(&db_path, config).or_else(|first_err| {
if is_io_open_error(&first_err) {
try_cleanup_stale_files(&db_path);
open_and_init(&db_path, config)
} else {
Err(first_err)
}
});
match conn {
Ok(conn) => {
let arc_conn = Arc::new(PMutex::new(conn));
conn_cache()
.connections
.lock()
.insert(db_path.clone(), Arc::clone(&arc_conn));
let breaker = {
let mut guard = conn_cache().breakers.lock();
guard
.entry(db_path.clone())
.or_insert_with(|| Arc::new(CircuitBreaker::new()))
.clone()
};
if breaker.record_success() {}
Ok(arc_conn)
}
Err(err) => {
let breaker = {
let mut guard = conn_cache().breakers.lock();
guard
.entry(db_path.clone())
.or_insert_with(|| Arc::new(CircuitBreaker::new()))
.clone()
};
if breaker.record_failure() {}
Err(err)
}
}
}
fn open_and_init(db_path: &Path, config: &MemoryConfig) -> Result<Connection> {
let dir = db_path.parent().expect("db_path always has a parent");
std::fs::create_dir_all(dir)
.with_context(|| format!("Failed to create chunk DB dir: {}", dir.display()))?;
let conn = Connection::open(db_path)
.with_context(|| format!("Failed to open chunk DB: {}", db_path.display()))?;
init_db(&conn, config)
.with_context(|| format!("Failed to init chunk DB schema: {}", db_path.display()))?;
record_schema_apply(db_path);
Ok(conn)
}
#[allow(dead_code)]
pub(crate) fn invalidate_connection(config: &MemoryConfig) {
let db_path = db_path_for(config);
conn_cache().connections.lock().remove(&db_path);
conn_cache().breakers.lock().remove(&db_path);
}
/// Forgets the cached connection and circuit breaker for `config`'s chunk DB.
///
/// The next `get_or_init_connection` call for this path opens a fresh connection with no
/// breaker history. The per-path init lock is kept.
pub(super) fn drop_cached_connection(config: &MemoryConfig) {
    let db_path = db_path_for(config);
    let cache = conn_cache();
    cache.connections.lock().remove(&db_path);
    cache.breakers.lock().remove(&db_path);
}
/// Test helper that empties every map in the global connection cache.
///
/// This clears cached connections, circuit breakers and per-path init locks.
#[cfg(test)]
pub(crate) fn clear_connection_cache() {
    let cache = conn_cache();
    cache.connections.lock().clear();
    cache.breakers.lock().clear();
    cache.init_locks.lock().clear();
}
/// Runs `f` with exclusive access to the chunk DB connection for `config`.
///
/// The connection is fetched, or lazily opened, via `get_or_init_connection`. Its mutex is
/// held for the whole duration of `f`, so calls targeting the same DB are serialized.
///
/// NOTE: the lock is not reentrant. Calling `with_connection` for the same DB from inside `f`
/// would wait on a lock this call already holds.
///
/// # Errors
/// Propagates connection-initialization errors and whatever error `f` returns.
pub fn with_connection<T>(
    config: &MemoryConfig,
    f: impl FnOnce(&Connection) -> Result<T>,
) -> Result<T> {
    let shared = get_or_init_connection(config)?;
    let conn = shared.lock();
    f(&*conn)
}
// Unit tests for this module live in the separate file `connection_tests.rs`.
#[cfg(test)]
#[path = "connection_tests.rs"]
mod tests;