use std::path::Path;
use std::sync::Arc;
use sqlx::SqlitePool;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::cache::{CacheStats, ClawCache};
use crate::config::ClawConfig;
use crate::error::{ClawError, ClawResult};
use crate::snapshot::{SnapshotManifest, SnapshotMeta, Snapshotter};
use crate::store::memory::{ListOptions, MemoryRecord, MemoryStore, MemoryType};
use crate::store::session_lifecycle::{Session, SessionLifecycleStore};
use crate::store::tool_output::{ToolOutputRecord, ToolOutputStore};
/// Row counts for the engine's main tables, as returned by `ClawEngine::db_stats`.
#[derive(Debug, Clone)]
pub struct DbStats {
/// Number of rows in the `memories` table.
pub memory_count: u64,
/// Number of rows in the `sessions` table.
pub session_count: u64,
/// Number of rows in the `tool_output` table.
pub tool_output_count: u64,
}
/// Summary of engine health, as returned by `ClawEngine::stats`.
#[derive(Debug, Clone)]
pub struct ClawStats {
/// Number of rows in the `memories` table.
pub total_memories: u64,
/// Rolling hit rate reported by the engine's cache statistics.
pub cache_hit_rate: f64,
/// Creation time of the most recent snapshot taken through this engine
/// instance, or `None` if `snapshot_create` has not succeeded yet.
pub last_snapshot_at: Option<chrono::DateTime<chrono::Utc>>,
/// Size of the database file in bytes (0 if its metadata cannot be read).
pub db_size_bytes: u64,
/// Size of the `-wal` sidecar file in bytes (0 if its metadata cannot be read).
pub wal_size_bytes: u64,
}
/// Entry point to the storage layer: owns the SQLite pool, an in-memory
/// cache of memory records keyed by id, and cache/snapshot bookkeeping.
#[derive(Debug)]
pub struct ClawEngine {
/// Configuration the engine was opened with.
pub(crate) config: ClawConfig,
/// Connection pool for the database at `config.db_path`.
pub(crate) pool: SqlitePool,
/// Cache of memory records, keyed by record id.
cache: Arc<Mutex<ClawCache<Uuid, MemoryRecord>>>,
/// Hit/miss/insert counters for `cache`.
stats: Arc<Mutex<CacheStats>>,
/// Creation time of the last snapshot taken via `snapshot_create`, if any.
last_snapshot_at: Arc<Mutex<Option<chrono::DateTime<chrono::Utc>>>>,
}
impl ClawEngine {
/// Opens the database described by `config`, creating the file if it does
/// not exist, and returns a ready-to-use engine.
///
/// The configured journal mode is applied to every pooled connection. When
/// the `encryption` feature is enabled and `config.encryption_key` is set,
/// the key is supplied as a connect-time pragma so that *every* connection
/// in the pool is keyed before any other statement touches the file.
/// If `config.auto_migrate` is set, migrations are run before returning.
///
/// # Errors
///
/// Returns `ClawError::Config` if the database URL cannot be parsed, and
/// propagates pool-connection, cache-construction and migration errors.
pub async fn open(config: ClawConfig) -> ClawResult<Self> {
    use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions};
    use std::str::FromStr;

    let db_url = format!("sqlite:{}", config.db_path.display());
    let journal_mode = match config.journal_mode {
        crate::config::JournalMode::WAL => SqliteJournalMode::Wal,
        crate::config::JournalMode::Delete => SqliteJournalMode::Delete,
        crate::config::JournalMode::Truncate => SqliteJournalMode::Truncate,
    };
    let connect_options = SqliteConnectOptions::from_str(&db_url)
        .map_err(|e| ClawError::Config(format!("invalid database URL: {e}")))?
        .create_if_missing(true)
        .journal_mode(journal_mode);

    // Running `PRAGMA key` through the pool after connecting would key only
    // one pooled connection, and only after the journal-mode pragma had
    // already been issued. Registering it as a connect option applies it to
    // each new connection instead.
    #[cfg(feature = "encryption")]
    let connect_options = match &config.encryption_key {
        Some(key) => {
            let hex: String = key.iter().map(|b| format!("{b:02x}")).collect();
            connect_options.pragma("key", format!("\"x'{hex}'\""))
        }
        None => connect_options,
    };

    let pool = SqlitePoolOptions::new()
        .max_connections(config.max_connections)
        .connect_with(connect_options)
        .await?;

    // Capacity is the byte budget divided by 512, with a floor of 64 entries.
    // The multiplication saturates so a huge `cache_size_mb` cannot overflow.
    let cache_cap = (config.cache_size_mb.saturating_mul(1024 * 1024) / 512).max(64);
    let cache = Arc::new(Mutex::new(ClawCache::new(cache_cap)?));
    let stats = Arc::new(Mutex::new(CacheStats::new()));

    let engine = ClawEngine {
        config,
        pool,
        cache,
        stats,
        last_snapshot_at: Arc::new(Mutex::new(None)),
    };
    if engine.config.auto_migrate {
        engine.migrate().await?;
    }
    Ok(engine)
}
pub async fn open_default() -> ClawResult<Self> {
ClawEngine::open(ClawConfig::default()).await
}
/// Runs the crate's schema migrations against this engine's pool.
///
/// # Errors
///
/// Propagates any error returned by `run_migrations`.
pub async fn migrate(&self) -> ClawResult<()> {
    use crate::schema::migrations::run_migrations;
    run_migrations(&self.pool).await
}
pub fn pool(&self) -> &SqlitePool {
&self.pool
}
pub fn config(&self) -> &ClawConfig {
&self.config
}
/// Closes the connection pool, consuming the engine.
pub async fn close(self) {
    let pool = &self.pool;
    pool.close().await;
}
/// Persists `record` through the memory store, then caches a copy of it and
/// bumps the cache insert counter.
///
/// Returns the id of the inserted record.
///
/// # Errors
///
/// Propagates any error from the store insert; nothing is cached in that case.
#[tracing::instrument(skip(self, record), fields(memory_id = %record.id))]
pub async fn insert_memory(&self, record: &MemoryRecord) -> ClawResult<Uuid> {
    let memory_id = record.id;
    MemoryStore::new(&self.pool).insert(record).await?;

    // Locks are taken cache-first, then stats, matching `get_memory`.
    let mut cache_guard = self.cache.lock().await;
    let mut stats_guard = self.stats.lock().await;
    cache_guard.insert(memory_id, record.clone());
    stats_guard.insert_count += 1;

    Ok(memory_id)
}
/// Fetches the memory record with the given `id`.
///
/// The cache is consulted first and a hit or miss is recorded in the cache
/// statistics. On a miss the record is loaded from the memory store and
/// placed in the cache before being returned.
///
/// # Errors
///
/// Propagates any error from the store lookup.
#[tracing::instrument(skip(self), fields(memory_id = %id))]
pub async fn get_memory(&self, id: Uuid) -> ClawResult<MemoryRecord> {
    let cached = {
        // Both guards are released at the end of this block, before the
        // database is queried.
        let mut cache_guard = self.cache.lock().await;
        let mut stats_guard = self.stats.lock().await;
        let hit = cache_guard.get(&id).cloned();
        if hit.is_some() {
            stats_guard.record_hit();
        } else {
            stats_guard.record_miss();
        }
        hit
    };
    if let Some(record) = cached {
        return Ok(record);
    }

    let fetched = MemoryStore::new(&self.pool).get(id).await?;
    self.cache.lock().await.insert(fetched.id, fetched.clone());
    Ok(fetched)
}
/// Replaces the content of the memory `id`, stamping it with the current
/// UTC time, and drops any cached copy so the next read reloads it.
///
/// # Errors
///
/// Propagates any error from the store update; the cache is left untouched
/// in that case.
#[tracing::instrument(skip(self), fields(memory_id = %id))]
pub async fn update_memory(&self, id: Uuid, content: &str) -> ClawResult<()> {
    let now = chrono::Utc::now();
    let pool = &self.pool;
    MemoryStore::new(pool).update_content(id, content, now).await?;

    self.cache.lock().await.invalidate(&id);
    Ok(())
}
/// Deletes the memory `id` from the store and evicts it from the cache.
///
/// # Errors
///
/// Propagates any error from the store delete; the cache is left untouched
/// in that case.
#[tracing::instrument(skip(self), fields(memory_id = %id))]
pub async fn delete_memory(&self, id: Uuid) -> ClawResult<()> {
    let pool = &self.pool;
    MemoryStore::new(pool).delete(id).await?;

    self.cache.lock().await.invalidate(&id);
    Ok(())
}
/// Lists memory records from the store, optionally restricted to
/// `type_filter`. The cache is not consulted.
///
/// # Errors
///
/// Propagates any error from the store.
#[tracing::instrument(skip(self))]
pub async fn list_memories(
    &self,
    type_filter: Option<MemoryType>,
) -> ClawResult<Vec<MemoryRecord>> {
    let filter = type_filter.as_ref();
    MemoryStore::new(&self.pool).list(filter).await
}
/// Lists one page of memory records, optionally restricted to
/// `type_filter`, using the paging options in `opts`.
///
/// Returns the page together with the optional string produced by the
/// store's `list_paginated` (presumably a continuation cursor — confirm
/// against `MemoryStore`).
///
/// # Errors
///
/// Propagates any error from the store.
#[tracing::instrument(skip(self))]
pub async fn list_memories_paginated(
    &self,
    type_filter: Option<MemoryType>,
    opts: ListOptions,
) -> ClawResult<(Vec<MemoryRecord>, Option<String>)> {
    let filter = type_filter.as_ref();
    let pool = &self.pool;
    MemoryStore::new(pool).list_paginated(filter, &opts).await
}
/// Returns the memory records the store associates with `tag`.
///
/// # Errors
///
/// Propagates any error from the store.
#[tracing::instrument(skip(self))]
pub async fn search_by_tag(&self, tag: &str) -> ClawResult<Vec<MemoryRecord>> {
    let pool = &self.pool;
    MemoryStore::new(pool).search_by_tag(tag).await
}
/// Runs the store's full-text search for `query` and returns the matches.
///
/// # Errors
///
/// Propagates any error from the store.
#[tracing::instrument(skip(self))]
pub async fn fts_search(&self, query: &str) -> ClawResult<Vec<MemoryRecord>> {
    let pool = &self.pool;
    MemoryStore::new(pool).fts_search(query).await
}
/// Asks the store to expire TTL-bound memories and returns how many rows
/// were removed.
///
/// The whole cache is cleared when at least one row was removed.
///
/// # Errors
///
/// Propagates any error from the store.
pub async fn expire_ttl_memories(&self) -> ClawResult<u64> {
    let removed = MemoryStore::new(&self.pool).expire_ttl().await?;
    if removed == 0 {
        // Nothing expired, so cached entries are still valid.
        return Ok(removed);
    }

    self.cache.lock().await.clear();
    Ok(removed)
}
/// Starts a new session via the session lifecycle store and returns the
/// string it yields (the session id).
///
/// # Errors
///
/// Propagates any error from the store.
pub async fn start_session(&self) -> ClawResult<String> {
    let pool = &self.pool;
    SessionLifecycleStore::new(pool).start().await
}
/// Ends the session identified by `session_id`.
///
/// # Errors
///
/// Propagates any error from the session lifecycle store.
pub async fn end_session(&self, session_id: &str) -> ClawResult<()> {
    let pool = &self.pool;
    SessionLifecycleStore::new(pool).end(session_id).await
}
/// Looks up the session identified by `session_id`.
///
/// # Errors
///
/// Propagates any error from the session lifecycle store.
pub async fn get_session(&self, session_id: &str) -> ClawResult<Session> {
    let pool = &self.pool;
    SessionLifecycleStore::new(pool).get(session_id).await
}
/// Lists the sessions known to the session lifecycle store.
///
/// # Errors
///
/// Propagates any error from the store.
pub async fn list_sessions(&self) -> ClawResult<Vec<Session>> {
    let pool = &self.pool;
    SessionLifecycleStore::new(pool).list().await
}
/// Persists a tool output record through the tool output store.
///
/// # Errors
///
/// Propagates any error from the store insert.
pub async fn record_tool_output(&self, output: &ToolOutputRecord) -> ClawResult<()> {
    let pool = &self.pool;
    ToolOutputStore::new(pool).insert(output).await
}
/// Returns the tool output records stored for `session_id`.
///
/// # Errors
///
/// Propagates any error from the tool output store.
pub async fn list_tool_outputs(&self, session_id: &str) -> ClawResult<Vec<ToolOutputRecord>> {
    let pool = &self.pool;
    ToolOutputStore::new(pool).get_by_session(session_id).await
}
pub async fn transaction(&self) -> ClawResult<crate::transaction::ClawTransaction<'_>> {
crate::transaction::ClawTransaction::begin(self).await
}
pub async fn begin_transaction(&self) -> ClawResult<crate::transaction::ClawTransaction<'_>> {
crate::transaction::ClawTransaction::begin(self).await
}
/// Takes a snapshot of the database file into `config.snapshot_dir`.
///
/// A `PRAGMA wal_checkpoint(FULL)` is issued first, then the snapshotter
/// copies from `config.db_path`. On success the snapshot's creation time is
/// remembered for `stats`.
///
/// # Errors
///
/// Returns `ClawError::Config` when no snapshot directory is configured, and
/// propagates checkpoint and snapshotter errors.
#[tracing::instrument(skip(self))]
pub async fn snapshot_create(&self) -> ClawResult<SnapshotMeta> {
    let snap_dir = match self.config.snapshot_dir.as_ref() {
        Some(dir) => dir,
        None => {
            return Err(ClawError::Config(
                "snapshot_dir must be set to use snapshot_create".to_string(),
            ));
        }
    };

    let checkpoint = sqlx::query("PRAGMA wal_checkpoint(FULL)");
    checkpoint.execute(&self.pool).await?;

    let meta = Snapshotter::new(snap_dir)?.take(&self.config.db_path)?;
    {
        let mut last_taken = self.last_snapshot_at.lock().await;
        *last_taken = Some(meta.created_at);
    }
    Ok(meta)
}
#[tracing::instrument(skip(self), fields(snapshot = %snapshot_path.display()))]
pub async fn restore(&mut self, snapshot_path: &Path) -> ClawResult<()> {
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use std::str::FromStr;
Self::validate_sqlite_magic(snapshot_path)?;
self.pool.close().await;
std::fs::copy(snapshot_path, &self.config.db_path)
.map_err(|e| ClawError::Snapshot(format!("failed to restore snapshot: {e}")))?;
let db_name = self
.config
.db_path
.file_name()
.unwrap_or_default()
.to_string_lossy()
.into_owned();
let db_parent = self
.config
.db_path
.parent()
.unwrap_or(std::path::Path::new("."));
for suffix in &["-wal", "-shm"] {
let sidecar = db_parent.join(format!("{db_name}{suffix}"));
if sidecar.exists() {
let _ = std::fs::remove_file(&sidecar);
}
}
let db_url = format!("sqlite:{}", self.config.db_path.display());
let connect_options = SqliteConnectOptions::from_str(&db_url)
.map_err(|e| ClawError::Config(format!("invalid database URL: {e}")))?
.create_if_missing(false);
self.pool = SqlitePoolOptions::new()
.max_connections(self.config.max_connections)
.connect_with(connect_options)
.await?;
#[cfg(feature = "encryption")]
if let Some(key) = &self.config.encryption_key {
let hex: String = key.iter().map(|b| format!("{b:02x}")).collect();
sqlx::query(&format!("PRAGMA key = \"x'{hex}'\""))
.execute(&self.pool)
.await?;
}
self.migrate().await?;
self.cache.lock().await.clear();
tracing::info!(
snapshot = %snapshot_path.display(),
db = %self.config.db_path.display(),
"database restored from snapshot"
);
Ok(())
}
/// Loads the snapshot manifest from `config.snapshot_dir`.
///
/// # Errors
///
/// Returns `ClawError::Config` when no snapshot directory is configured, and
/// propagates errors from constructing the snapshotter or loading the
/// manifest.
pub fn snapshot_manifest(&self) -> ClawResult<SnapshotManifest> {
    let snap_dir = match self.config.snapshot_dir.as_ref() {
        Some(dir) => dir,
        None => return Err(ClawError::Config("snapshot_dir must be set".to_string())),
    };
    let snapshotter = Snapshotter::new(snap_dir)?;
    snapshotter.load_manifest()
}
/// Issues `PRAGMA rekey` with `new_key` (hex-encoded) against the pool.
///
/// NOTE(review): `_old_key` is not used, the statement runs on a single
/// connection checked out from the pool, and `config.encryption_key` is not
/// updated here — confirm how other pooled connections and later
/// reconnects pick up the new key.
///
/// # Errors
///
/// Propagates any error from executing the pragma.
#[cfg(feature = "encryption")]
pub async fn rotate_key(&self, _old_key: [u8; 32], new_key: [u8; 32]) -> ClawResult<()> {
    let new_key_hex = new_key
        .iter()
        .map(|byte| format!("{byte:02x}"))
        .collect::<String>();
    let rekey_sql = format!("PRAGMA rekey = \"x'{new_key_hex}'\"");
    sqlx::query(&rekey_sql).execute(&self.pool).await?;
    Ok(())
}
pub async fn cache_stats(&self) -> CacheStats {
self.stats.lock().await.clone()
}
/// Collects a `ClawStats` summary: memory row count, rolling cache hit
/// rate, time of the last snapshot taken through this engine, and the
/// on-disk sizes of the database file and its `-wal` sidecar.
///
/// File sizes fall back to 0 when the file's metadata cannot be read (for
/// example when no WAL file exists).
///
/// # Errors
///
/// Propagates any error from the `COUNT(*)` query.
#[tracing::instrument(skip(self))]
pub async fn stats(&self) -> ClawResult<ClawStats> {
    let (total_memories,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM memories")
        .fetch_one(&self.pool)
        .await?;
    let cache_hit_rate = self.stats.lock().await.rolling_hit_rate();
    let last_snapshot_at = *self.last_snapshot_at.lock().await;

    let db_size_bytes = std::fs::metadata(&self.config.db_path)
        .map(|m| m.len())
        .unwrap_or(0);

    // Append the suffix to the raw OS string: a lossy UTF-8 round-trip would
    // point at the wrong file when the database path is not valid UTF-8.
    let wal_path = {
        let mut raw = self.config.db_path.as_os_str().to_os_string();
        raw.push("-wal");
        std::path::PathBuf::from(raw)
    };
    let wal_size_bytes = std::fs::metadata(&wal_path).map(|m| m.len()).unwrap_or(0);

    Ok(ClawStats {
        // COUNT(*) is never negative, so the cast cannot wrap.
        total_memories: total_memories as u64,
        cache_hit_rate,
        last_snapshot_at,
        db_size_bytes,
        wal_size_bytes,
    })
}
/// Counts the rows in the `memories`, `sessions` and `tool_output` tables.
///
/// The three counts come from separate queries, so they are not taken from
/// a single consistent snapshot of the database.
///
/// # Errors
///
/// Propagates the first query error encountered.
pub async fn db_stats(&self) -> ClawResult<DbStats> {
    let pool = &self.pool;

    let memory_row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM memories")
        .fetch_one(pool)
        .await?;
    let session_row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM sessions")
        .fetch_one(pool)
        .await?;
    let tool_output_row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM tool_output")
        .fetch_one(pool)
        .await?;

    // COUNT(*) is never negative, so these casts cannot wrap.
    Ok(DbStats {
        memory_count: memory_row.0 as u64,
        session_count: session_row.0 as u64,
        tool_output_count: tool_output_row.0 as u64,
    })
}
/// Checks that the file at `path` begins with the 16-byte SQLite 3 header
/// string (`"SQLite format 3\0"`).
///
/// # Errors
///
/// Returns `ClawError::Snapshot` if the file cannot be opened, is shorter
/// than 16 bytes or otherwise unreadable, or starts with different bytes.
fn validate_sqlite_magic(path: &Path) -> ClawResult<()> {
    use std::io::Read;

    const SQLITE_MAGIC: &[u8; 16] = b"SQLite format 3\0";

    let mut snapshot = match std::fs::File::open(path) {
        Ok(handle) => handle,
        Err(e) => return Err(ClawError::Snapshot(format!("cannot open snapshot: {e}"))),
    };

    let mut leading_bytes = [0u8; 16];
    if let Err(e) = snapshot.read_exact(&mut leading_bytes) {
        return Err(ClawError::Snapshot(format!(
            "cannot read snapshot header: {e}"
        )));
    }

    if leading_bytes == *SQLITE_MAGIC {
        Ok(())
    } else {
        Err(ClawError::Snapshot(
            "file does not have a valid SQLite 3 header".to_string(),
        ))
    }
}
}