use std::collections::VecDeque;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use chrono::{DateTime, TimeZone, Utc};
use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions};
use sqlx::SqlitePool;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::cache::{CacheStats, ClawCache};
use crate::config::ClawConfig;
use crate::error::{ClawError, ClawResult};
use crate::snapshot::{
blake3_file_hex, manifest_path_for, verify_snapshot_integrity, SnapshotManifest, SnapshotMeta,
};
use crate::store::memory::{ListOptions, ListPage, MemoryRecord, MemoryStore, MemoryType};
use crate::store::session_lifecycle::{Session, SessionLifecycleStore};
use crate::store::tool_output::{ToolOutputRecord, ToolOutputStore};
/// Row counts reported by `ClawEngine::db_stats`.
#[derive(Debug, Clone)]
pub struct DbStats {
/// Result of `SELECT COUNT(*) FROM memories`.
pub memory_count: u64,
/// Result of `SELECT COUNT(*) FROM sessions`.
pub session_count: u64,
/// Result of `SELECT COUNT(*) FROM tool_output`.
pub tool_output_count: u64,
}
/// Combined database, cache and file-size statistics reported by `ClawEngine::stats`.
#[derive(Debug, Clone)]
pub struct ClawStats {
/// Result of `SELECT COUNT(*) FROM memories`.
pub total_memories: u64,
/// Result of `SELECT COUNT(*) FROM sessions`.
pub total_sessions: u64,
/// Fraction of hits among the most recent `get_memory` lookups held in the
/// engine's read window (at most 1000 entries); `0.0` when the window is empty.
pub cache_hit_rate: f64,
/// Value of `len()` on the record cache at the time of the call.
pub cache_size: usize,
/// Size of the database file in bytes; `0` if its metadata could not be read.
pub db_size_bytes: u64,
/// Size of the `<db_path>-wal` file in bytes; `0` if its metadata could not be read.
pub wal_size_bytes: u64,
/// When the last successful `snapshot()` on this engine instance finished,
/// or `None` if none has completed since the engine was opened.
pub last_snapshot_at: Option<DateTime<Utc>>,
}
/// Engine that bundles a SQLite connection pool with a record cache and
/// read/snapshot bookkeeping.
#[derive(Debug)]
pub struct ClawEngine {
/// Configuration the engine was opened with.
pub(crate) config: ClawConfig,
/// Connection pool; replaced by `restore` after the database file is swapped.
pub(crate) pool: SqlitePool,
/// Cache of memory records keyed by record id.
pub(crate) cache: Arc<Mutex<ClawCache<Uuid, MemoryRecord>>>,
/// Cache counters (inserts, hits, misses) returned by `cache_stats`.
stats: Arc<Mutex<CacheStats>>,
/// Completion time of the most recent successful `snapshot()`, if any.
last_snapshot_at: Arc<Mutex<Option<DateTime<Utc>>>>,
/// Lifetime count of `get_memory` calls answered from the cache.
cache_hits: AtomicU64,
/// Lifetime count of `get_memory` calls that missed the cache.
cache_misses: AtomicU64,
/// Outcomes (`true` = hit) of the most recent `get_memory` calls, capped at 1000 entries.
read_window: Arc<Mutex<VecDeque<bool>>>,
}
impl ClawEngine {
/// Opens the database described by `config` (creating the file if it is missing)
/// and builds an engine around it.
///
/// When the `encryption` feature is enabled and a key is configured, the key
/// pragma is applied right after connecting. Migrations run only when
/// `config.auto_migrate` is set.
///
/// # Errors
/// Returns an error if connecting, applying the key, constructing the cache,
/// or running migrations fails.
#[tracing::instrument(skip(config), fields(workspace_id = %config.workspace_id))]
pub async fn open(config: ClawConfig) -> ClawResult<Self> {
    let pool = Self::connect_pool(&config, true).await?;

    #[cfg(feature = "encryption")]
    if let Some(key) = &config.encryption_key {
        Self::apply_pragmas_key(&pool, key).await?;
    }

    // Capacity is one entry per 512 bytes of the configured budget:
    // (mb * 1024 * 1024) / 512 == mb * 2048. The saturating form yields the same
    // value for every input the original could compute, but cannot overflow for
    // very large budgets. The floor of 64 entries is kept.
    let cache_cap = config.cache_size_mb.saturating_mul(2048).max(64);
    let cache = Arc::new(Mutex::new(ClawCache::new(cache_cap)?));

    let engine = ClawEngine {
        config,
        pool,
        cache,
        stats: Arc::new(Mutex::new(CacheStats::new())),
        last_snapshot_at: Arc::new(Mutex::new(None)),
        cache_hits: AtomicU64::new(0),
        cache_misses: AtomicU64::new(0),
        // Pre-sized to the window cap enforced by `push_read_window`.
        read_window: Arc::new(Mutex::new(VecDeque::with_capacity(1000))),
    };

    if engine.config.auto_migrate {
        engine.migrate().await?;
    }
    Ok(engine)
}
#[tracing::instrument(fields(workspace_id = "default"))]
pub async fn open_default() -> ClawResult<Self> {
ClawEngine::open(ClawConfig::default()).await
}
/// Runs the crate's schema migrations against this engine's pool.
///
/// # Errors
/// Propagates any error returned by `run_migrations`.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn migrate(&self) -> ClawResult<()> {
    let pool = &self.pool;
    crate::schema::migrations::run_migrations(pool).await
}
pub fn pool(&self) -> &SqlitePool {
&self.pool
}
pub fn config(&self) -> &ClawConfig {
&self.config
}
/// Consumes the engine and closes its connection pool.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn close(self) {
    let pool = &self.pool;
    pool.close().await;
}
/// Persists `record` through `MemoryStore`, then caches a clone of it and
/// bumps the cache insert counter.
///
/// Returns the id of the inserted record.
///
/// # Errors
/// Propagates the store's insert error; nothing is cached in that case.
#[tracing::instrument(skip(self, record), fields(workspace_id = %self.config.workspace_id, memory_id = %record.id))]
pub async fn insert_memory(&self, record: &MemoryRecord) -> ClawResult<Uuid> {
    let store = MemoryStore::new(&self.pool);
    store.insert(record).await?;

    let memory_id = record.id;
    {
        // Same acquisition order as the original: cache first, then stats.
        let mut cache_guard = self.cache.lock().await;
        let mut stats_guard = self.stats.lock().await;
        cache_guard.insert(memory_id, record.clone());
        stats_guard.insert_count += 1;
    }
    Ok(memory_id)
}
/// Returns the memory record with the given `id`, consulting the cache first.
///
/// A cache hit or miss is recorded in three places: the lifetime atomic
/// counters, the sliding read window, and `CacheStats`. On a miss the record is
/// loaded through `MemoryStore` and a clone is placed in the cache.
///
/// # Errors
/// Propagates the store's lookup error on a cache miss.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id, memory_id = %id))]
pub async fn get_memory(&self, id: Uuid) -> ClawResult<MemoryRecord> {
    // Clone the cached record and release the cache lock before touching the
    // other mutexes, so concurrent readers are not serialized behind the
    // bookkeeping below.
    let mut cache = self.cache.lock().await;
    let cached = cache.get(&id).map(|record| record.clone());
    drop(cache);

    if let Some(record) = cached {
        self.cache_hits.fetch_add(1, Ordering::Relaxed);
        self.push_read_window(true).await;
        self.stats.lock().await.record_hit();
        return Ok(record);
    }

    self.cache_misses.fetch_add(1, Ordering::Relaxed);
    self.push_read_window(false).await;
    self.stats.lock().await.record_miss();

    let record = MemoryStore::new(&self.pool).get(id).await?;
    self.cache.lock().await.insert(record.id, record.clone());
    Ok(record)
}
/// Replaces the content of memory `id`, stamping it with the current UTC time,
/// then invalidates the cached entry for that id.
///
/// # Errors
/// Propagates the store's update error; the cache is left untouched in that case.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id, memory_id = %id))]
pub async fn update_memory(&self, id: Uuid, content: &str) -> ClawResult<()> {
    let now = Utc::now();
    let store = MemoryStore::new(&self.pool);
    store.update_content(id, content, now).await?;

    let mut cache_guard = self.cache.lock().await;
    cache_guard.invalidate(&id);
    Ok(())
}
/// Deletes memory `id` through `MemoryStore` and invalidates its cache entry.
///
/// # Errors
/// Propagates the store's delete error; the cache is left untouched in that case.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id, memory_id = %id))]
pub async fn delete_memory(&self, id: Uuid) -> ClawResult<()> {
    let store = MemoryStore::new(&self.pool);
    store.delete(id).await?;

    let mut cache_guard = self.cache.lock().await;
    cache_guard.invalidate(&id);
    Ok(())
}
/// Lists memory records via `MemoryStore::list`, optionally restricted to
/// `type_filter`. The cache is not consulted.
///
/// # Errors
/// Propagates the store's error.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn list_memories(
    &self,
    type_filter: Option<MemoryType>,
) -> ClawResult<Vec<MemoryRecord>> {
    let store = MemoryStore::new(&self.pool);
    let filter = type_filter.as_ref();
    store.list(filter).await
}
/// Lists one page of memory records via `MemoryStore::list_paginated`,
/// optionally restricted to `type_filter`, using the paging options in `opts`.
///
/// # Errors
/// Propagates the store's error.
#[tracing::instrument(skip(self, opts), fields(workspace_id = %self.config.workspace_id))]
pub async fn list_memories_paginated(
    &self,
    type_filter: Option<MemoryType>,
    opts: ListOptions,
) -> ClawResult<ListPage<MemoryRecord>> {
    let store = MemoryStore::new(&self.pool);
    let filter = type_filter.as_ref();
    store.list_paginated(filter, &opts).await
}
#[tracing::instrument(skip(self, opts), fields(workspace_id = %self.config.workspace_id))]
pub async fn get_memories_by_type(
&self,
memory_type: MemoryType,
opts: Option<ListOptions>,
) -> ClawResult<ListPage<MemoryRecord>> {
let options = opts.unwrap_or_default();
MemoryStore::new(&self.pool)
.list_paginated(Some(&memory_type), &options)
.await
}
/// Searches memory records by `tag`, passing a fixed limit of 50 and offset 0
/// to `MemoryStore::search_by_tag`.
///
/// # Errors
/// Propagates the store's error.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn search_by_tag(&self, tag: &str) -> ClawResult<Vec<MemoryRecord>> {
    const DEFAULT_LIMIT: u32 = 50;
    const DEFAULT_OFFSET: u32 = 0;

    let store = MemoryStore::new(&self.pool);
    store.search_by_tag(tag, DEFAULT_LIMIT, DEFAULT_OFFSET).await
}
/// Searches memory records by `tag` with caller-supplied paging.
///
/// `limit` is forced into the range `1..=1000` before it reaches the store;
/// `offset` is passed through unchanged.
///
/// # Errors
/// Propagates the store's error.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn search_by_tag_paginated(
    &self,
    tag: &str,
    limit: u32,
    offset: u32,
) -> ClawResult<Vec<MemoryRecord>> {
    // Equivalent to clamping into [1, 1000].
    let page_size = limit.max(1).min(1000);
    let store = MemoryStore::new(&self.pool);
    store.search_by_tag(tag, page_size, offset).await
}
/// Forwards `query` to `MemoryStore::fts_search` and returns its matches.
///
/// # Errors
/// Propagates the store's error.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn fts_search(&self, query: &str) -> ClawResult<Vec<MemoryRecord>> {
    let store = MemoryStore::new(&self.pool);
    store.fts_search(query).await
}
/// Runs `MemoryStore::expire_ttl` and returns the count it reports.
///
/// The whole cache is cleared whenever that count is non-zero.
///
/// # Errors
/// Propagates the store's error.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn expire_ttl_memories(&self) -> ClawResult<u64> {
    let store = MemoryStore::new(&self.pool);
    let removed = store.expire_ttl().await?;
    if removed != 0 {
        let mut cache_guard = self.cache.lock().await;
        cache_guard.clear();
    }
    Ok(removed)
}
/// Starts a session via `SessionLifecycleStore::start` and returns the
/// string it yields.
///
/// # Errors
/// Propagates the store's error.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn start_session(&self) -> ClawResult<String> {
    let sessions = SessionLifecycleStore::new(&self.pool);
    sessions.start().await
}
/// Ends the session identified by `session_id` via `SessionLifecycleStore::end`.
///
/// # Errors
/// Propagates the store's error.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn end_session(&self, session_id: &str) -> ClawResult<()> {
    let sessions = SessionLifecycleStore::new(&self.pool);
    sessions.end(session_id).await
}
/// Fetches the session identified by `session_id` via `SessionLifecycleStore::get`.
///
/// # Errors
/// Propagates the store's error.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn get_session(&self, session_id: &str) -> ClawResult<Session> {
    let sessions = SessionLifecycleStore::new(&self.pool);
    sessions.get(session_id).await
}
#[tracing::instrument(skip(self, opts), fields(workspace_id = %self.config.workspace_id))]
pub async fn list_sessions(&self, opts: Option<ListOptions>) -> ClawResult<ListPage<Session>> {
let options = opts.unwrap_or_default();
SessionLifecycleStore::new(&self.pool)
.list_paginated(&options)
.await
}
/// Persists `output` via `ToolOutputStore::insert`.
///
/// # Errors
/// Propagates the store's error.
#[tracing::instrument(skip(self, output), fields(workspace_id = %self.config.workspace_id))]
pub async fn record_tool_output(&self, output: &ToolOutputRecord) -> ClawResult<()> {
    let tool_outputs = ToolOutputStore::new(&self.pool);
    tool_outputs.insert(output).await
}
/// Returns the tool outputs that `ToolOutputStore::get_by_session` yields for
/// `session_id`.
///
/// # Errors
/// Propagates the store's error.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn list_tool_outputs(&self, session_id: &str) -> ClawResult<Vec<ToolOutputRecord>> {
    let tool_outputs = ToolOutputStore::new(&self.pool);
    tool_outputs.get_by_session(session_id).await
}
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn transaction(&self) -> ClawResult<crate::transaction::ClawTransaction<'_>> {
crate::transaction::ClawTransaction::begin(self).await
}
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn begin_transaction(&self) -> ClawResult<crate::transaction::ClawTransaction<'_>> {
crate::transaction::ClawTransaction::begin(self).await
}
/// Writes a file-level snapshot of the database into `config.snapshot_dir`
/// and returns the path of the new `<unix-ms>.db` file.
///
/// Steps, in order:
/// 1. create the snapshot directory if needed;
/// 2. run `PRAGMA wal_checkpoint(FULL)`;
/// 3. copy the database file to `<unix-ms>.db.tmp`, then rename it to `<unix-ms>.db`;
/// 4. write a JSON manifest (size and BLAKE3 hex digest) next to it;
/// 5. record the completion time in `last_snapshot_at`.
///
/// # Errors
/// Returns `ClawError::Config` when `snapshot_dir` is unset, `ClawError::Snapshot`
/// for copy, rename, stat, or manifest failures, and propagates I/O and SQL errors
/// from the remaining steps. The temporary file is removed (best effort) when the
/// copy or the rename fails.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn snapshot(&self) -> ClawResult<PathBuf> {
    let snapshot_dir = self
        .config
        .snapshot_dir
        .as_ref()
        .ok_or_else(|| ClawError::Config("snapshot_dir must be set".to_string()))?;
    std::fs::create_dir_all(snapshot_dir)?;

    // NOTE(review): the pool stays open, so writes that land between this
    // checkpoint and the copy below are not guaranteed to be in the snapshot.
    sqlx::query("PRAGMA wal_checkpoint(FULL)")
        .execute(&self.pool)
        .await?;

    // The millisecond timestamp doubles as the snapshot's file name.
    let created_at_ms = Utc::now().timestamp_millis() as u64;
    let final_path = snapshot_dir.join(format!("{created_at_ms}.db"));
    // Built with `join` rather than a `display()` round-trip so a non-UTF-8
    // snapshot directory is preserved exactly.
    let tmp_path = snapshot_dir.join(format!("{created_at_ms}.db.tmp"));

    if let Err(e) = std::fs::copy(&self.config.db_path, &tmp_path) {
        // Do not leave a partial temp file behind.
        let _ = std::fs::remove_file(&tmp_path);
        return Err(ClawError::Snapshot(format!(
            "failed to copy '{}' to '{}': {e}",
            self.config.db_path.display(),
            tmp_path.display()
        )));
    }
    if let Err(e) = std::fs::rename(&tmp_path, &final_path) {
        let _ = std::fs::remove_file(&tmp_path);
        return Err(ClawError::Snapshot(format!(
            "failed to rename '{}' to '{}': {e}",
            tmp_path.display(),
            final_path.display()
        )));
    }

    let size_bytes = std::fs::metadata(&final_path)
        .map_err(|e| ClawError::Snapshot(format!("failed to stat snapshot file: {e}")))?
        .len();
    let blake3 = blake3_file_hex(&final_path)?;
    let manifest = SnapshotManifest {
        version: 1,
        created_at_ms,
        source_db: self.config.db_path.display().to_string(),
        size_bytes,
        blake3,
    };

    let manifest_path = manifest_path_for(&final_path);
    let manifest_bytes = serde_json::to_vec_pretty(&manifest)
        .map_err(|e| ClawError::Snapshot(format!("failed to serialize manifest: {e}")))?;
    std::fs::write(&manifest_path, manifest_bytes).map_err(|e| {
        ClawError::Snapshot(format!(
            "failed to write manifest '{}': {e}",
            manifest_path.display()
        ))
    })?;

    *self.last_snapshot_at.lock().await = Some(Utc::now());
    Ok(final_path)
}
/// Takes a snapshot via `snapshot()` and describes it as a `SnapshotMeta`.
///
/// The creation time is parsed back from the snapshot's file stem (a Unix
/// millisecond timestamp); size and checksum are read from the file on disk.
///
/// # Errors
/// Propagates errors from `snapshot()`, from reading the file's metadata or
/// hashing it, and returns `ClawError::Snapshot` when the file stem is not a
/// valid millisecond timestamp.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn snapshot_create(&self) -> ClawResult<SnapshotMeta> {
    let path = self.snapshot().await?;

    let stem_ms = path
        .file_stem()
        .and_then(|stem| stem.to_str())
        .and_then(|text| text.parse::<u64>().ok());
    let created_at_ms = match stem_ms {
        Some(ms) => ms,
        None => {
            return Err(ClawError::Snapshot(
                "snapshot filename is not a unix-ms timestamp".to_string(),
            ))
        }
    };

    let created_at = match Utc.timestamp_millis_opt(created_at_ms as i64).single() {
        Some(timestamp) => timestamp,
        None => return Err(ClawError::Snapshot("invalid snapshot timestamp".to_string())),
    };

    let file_meta = std::fs::metadata(&path)?;
    let size_bytes = file_meta.len();
    let checksum = blake3_file_hex(&path)?;

    Ok(SnapshotMeta {
        path,
        created_at,
        size_bytes,
        checksum,
    })
}
/// Replaces the live database file with the snapshot at `snapshot_path` and
/// reconnects the pool.
///
/// Steps, in order: verify the snapshot with `verify_snapshot_integrity`, close
/// the pool, clear the record cache, remove the `-wal` / `-shm` side files, copy
/// the snapshot over `config.db_path`, reconnect without `create_if_missing`,
/// re-apply the encryption key (when that feature is enabled and a key is
/// configured), and run migrations.
///
/// # Errors
/// Returns the first error from any step. If a step after closing the pool
/// fails, the engine is left holding a closed pool. The cache is already empty
/// at that point, so no record from before the restore can be served.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id, snapshot = %snapshot_path.display()))]
pub async fn restore(&mut self, snapshot_path: &Path) -> ClawResult<()> {
    verify_snapshot_integrity(snapshot_path)?;
    self.pool.close().await;

    // Clear before any step that can fail: once the file is replaced, cached
    // records describe the old database even if reconnecting or migrating
    // subsequently errors out.
    self.cache.lock().await.clear();

    // Stale WAL/SHM files belong to the old database and must not be replayed
    // onto the restored one. Remove them directly and tolerate "not found"
    // instead of a racy exists() check; build the names from the OS string so
    // non-UTF-8 paths are not mangled.
    for suffix in ["-wal", "-shm"] {
        let mut sidecar = self.config.db_path.as_os_str().to_os_string();
        sidecar.push(suffix);
        match std::fs::remove_file(PathBuf::from(sidecar)) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => return Err(e.into()),
        }
    }

    std::fs::copy(snapshot_path, &self.config.db_path).map_err(|e| {
        ClawError::Snapshot(format!(
            "failed to restore snapshot '{}' into '{}': {e}",
            snapshot_path.display(),
            self.config.db_path.display()
        ))
    })?;

    self.pool = Self::connect_pool(&self.config, false).await?;
    #[cfg(feature = "encryption")]
    if let Some(key) = &self.config.encryption_key {
        Self::apply_pragmas_key(&self.pool, key).await?;
    }

    // Run unconditionally (not gated on `auto_migrate`), as in the original flow.
    self.migrate().await?;
    Ok(())
}
pub fn list_snapshots(&self) -> ClawResult<Vec<SnapshotManifest>> {
let snapshot_dir = self
.config
.snapshot_dir
.as_ref()
.ok_or_else(|| ClawError::Config("snapshot_dir must be set".to_string()))?;
let mut manifests = Vec::new();
for entry in std::fs::read_dir(snapshot_dir)? {
let path = entry?.path();
if path
.file_name()
.and_then(|n| n.to_str())
.map(|n| n.ends_with(".manifest.json"))
.unwrap_or(false)
{
let bytes = std::fs::read(&path)?;
let manifest: SnapshotManifest = serde_json::from_slice(&bytes).map_err(|e| {
ClawError::Snapshot(format!("cannot parse manifest '{}': {e}", path.display()))
})?;
manifests.push(manifest);
}
}
manifests.sort_by(|a, b| b.created_at_ms.cmp(&a.created_at_ms));
Ok(manifests)
}
/// Deletes the snapshot file at `path` together with the manifest that
/// `manifest_path_for(path)` names.
///
/// Either file being absent is not an error.
///
/// # Errors
/// Propagates any I/O error other than "not found" from removing either file.
pub fn delete_snapshot(&self, path: &Path) -> ClawResult<()> {
    /// Removes `target`, treating an already-missing file as success.
    ///
    /// Removing directly avoids the window between an `exists()` check and the
    /// removal, and surfaces errors that `exists()` would report as "absent".
    fn remove_if_present(target: &Path) -> std::io::Result<()> {
        match std::fs::remove_file(target) {
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
            other => other,
        }
    }

    remove_if_present(path)?;
    remove_if_present(&manifest_path_for(path))?;
    Ok(())
}
/// Re-encrypts the database from `old_key` to `new_key`.
///
/// Issues `PRAGMA key` with the old key followed by `PRAGMA rekey` with the new
/// key. Both statements run on one connection acquired from the pool: executed
/// against the pool itself, they could be dispatched to different connections,
/// and the rekey would then run on a connection that never received the old key.
///
/// NOTE(review): other connections already open in the pool, and
/// `config.encryption_key`, are not updated here — confirm how callers are
/// expected to reopen the engine after rotation.
///
/// # Errors
/// Propagates errors from acquiring the connection or executing either pragma.
#[cfg(feature = "encryption")]
#[tracing::instrument(skip(self, old_key, new_key), fields(workspace_id = %self.config.workspace_id))]
pub async fn rotate_key(&self, old_key: [u8; 32], new_key: [u8; 32]) -> ClawResult<()> {
    let old_hex: String = old_key.iter().map(|b| format!("{b:02x}")).collect();
    let new_hex: String = new_key.iter().map(|b| format!("{b:02x}")).collect();

    let mut conn = self.pool.acquire().await?;
    sqlx::query(&format!("PRAGMA key = \"x'{old_hex}'\""))
        .execute(&mut *conn)
        .await?;
    sqlx::query(&format!("PRAGMA rekey = \"x'{new_hex}'\""))
        .execute(&mut *conn)
        .await?;
    Ok(())
}
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn cache_stats(&self) -> CacheStats {
self.stats.lock().await.clone()
}
/// Collects row counts, cache figures and on-disk sizes into a `ClawStats`.
///
/// `cache_hit_rate` is the share of hits in the sliding read window (`0.0` when
/// the window is empty). File sizes fall back to `0` when the file's metadata
/// cannot be read, so a missing WAL file is not an error.
///
/// # Errors
/// Propagates errors from the two `COUNT(*)` queries.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn stats(&self) -> ClawResult<ClawStats> {
    let (total_memories,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM memories")
        .fetch_one(&self.pool)
        .await?;
    let (total_sessions,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM sessions")
        .fetch_one(&self.pool)
        .await?;

    let cache_size = self.cache.lock().await.len();
    let cache_hit_rate = {
        let window = self.read_window.lock().await;
        if window.is_empty() {
            0.0
        } else {
            let hits = window.iter().filter(|&&v| v).count();
            hits as f64 / window.len() as f64
        }
    };

    let db_size_bytes = std::fs::metadata(&self.config.db_path)
        .map(|m| m.len())
        .unwrap_or(0);
    // Append the suffix to the OS string rather than going through `display()`,
    // which is lossy for non-UTF-8 paths.
    let mut wal_path = self.config.db_path.as_os_str().to_os_string();
    wal_path.push("-wal");
    let wal_size_bytes = std::fs::metadata(&wal_path).map(|m| m.len()).unwrap_or(0);

    let last_snapshot_at = *self.last_snapshot_at.lock().await;

    Ok(ClawStats {
        // COUNT(*) is never negative, so these casts do not change the value.
        total_memories: total_memories as u64,
        total_sessions: total_sessions as u64,
        cache_hit_rate,
        cache_size,
        db_size_bytes,
        wal_size_bytes,
        last_snapshot_at,
    })
}
/// Counts the rows of the `memories`, `sessions` and `tool_output` tables.
///
/// The three counts come from three separate queries, issued in that order.
///
/// # Errors
/// Propagates the first query error encountered.
#[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
pub async fn db_stats(&self) -> ClawResult<DbStats> {
    let memory_row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM memories")
        .fetch_one(&self.pool)
        .await?;
    let session_row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM sessions")
        .fetch_one(&self.pool)
        .await?;
    let tool_output_row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM tool_output")
        .fetch_one(&self.pool)
        .await?;

    let counts = DbStats {
        memory_count: memory_row.0 as u64,
        session_count: session_row.0 as u64,
        tool_output_count: tool_output_row.0 as u64,
    };
    Ok(counts)
}
/// Builds a SQLite pool for `config.db_path`.
///
/// The journal mode is taken from `config.journal_mode` and the pool size from
/// `config.max_connections`. `create_if_missing` controls whether a missing
/// database file is created.
///
/// With the `encryption` feature enabled and a key configured, the key is
/// registered as a connect-time pragma so that every connection the pool opens
/// is keyed. Without this, the connect-time `journal_mode` pragma would run on
/// an unkeyed handle, and connections opened later by the pool would never
/// receive the key.
///
/// # Errors
/// Returns `ClawError::Config` when the database URL cannot be parsed, and
/// propagates connection errors.
async fn connect_pool(config: &ClawConfig, create_if_missing: bool) -> ClawResult<SqlitePool> {
    let db_url = format!("sqlite:{}", config.db_path.display());
    let journal_mode = match config.journal_mode {
        crate::config::JournalMode::WAL => SqliteJournalMode::Wal,
        crate::config::JournalMode::Delete => SqliteJournalMode::Delete,
        crate::config::JournalMode::Truncate => SqliteJournalMode::Truncate,
    };
    let connect_options = SqliteConnectOptions::from_str(&db_url)
        .map_err(|e| ClawError::Config(format!("invalid database URL: {e}")))?
        .create_if_missing(create_if_missing)
        .journal_mode(journal_mode);

    // Same raw-key literal format that `apply_pragmas_key` sends.
    #[cfg(feature = "encryption")]
    let connect_options = match &config.encryption_key {
        Some(key) => {
            let hex: String = key.iter().map(|b| format!("{b:02x}")).collect();
            connect_options.pragma("key", format!("\"x'{hex}'\""))
        }
        None => connect_options,
    };

    let pool = SqlitePoolOptions::new()
        .max_connections(config.max_connections)
        .connect_with(connect_options)
        .await?;
    Ok(pool)
}
/// Appends one lookup outcome (`true` = cache hit) to the sliding read window,
/// dropping the oldest entry first when the window already holds 1000 entries.
async fn push_read_window(&self, hit: bool) {
    const WINDOW_CAP: usize = 1000;

    let mut recent = self.read_window.lock().await;
    let is_full = recent.len() >= WINDOW_CAP;
    if is_full {
        let _oldest = recent.pop_front();
    }
    recent.push_back(hit);
}
/// Sends `PRAGMA key` with `key` rendered as a lowercase-hex raw-key literal
/// (`"x'…'"`), executing the statement against `pool`.
///
/// # Errors
/// Propagates the error from executing the pragma.
#[cfg(feature = "encryption")]
async fn apply_pragmas_key(pool: &SqlitePool, key: &[u8; 32]) -> ClawResult<()> {
    let mut hex = String::with_capacity(key.len() * 2);
    for byte in key {
        hex.push_str(&format!("{byte:02x}"));
    }

    let statement = format!("PRAGMA key = \"x'{hex}'\"");
    sqlx::query(&statement).execute(pool).await?;
    Ok(())
}
}