1use std::collections::VecDeque;
4use std::path::{Path, PathBuf};
5use std::str::FromStr;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Arc;
8
9use chrono::{DateTime, TimeZone, Utc};
10use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions};
11use sqlx::SqlitePool;
12use tokio::sync::Mutex;
13use uuid::Uuid;
14
15use crate::cache::{CacheStats, ClawCache};
16use crate::config::ClawConfig;
17use crate::error::{ClawError, ClawResult};
18use crate::snapshot::{
19 blake3_file_hex, manifest_path_for, verify_snapshot_integrity, SnapshotManifest, SnapshotMeta,
20};
21use crate::store::memory::{ListOptions, ListPage, MemoryRecord, MemoryStore, MemoryType};
22use crate::store::session_lifecycle::{Session, SessionLifecycleStore};
23use crate::store::tool_output::{ToolOutputRecord, ToolOutputStore};
24
/// Row counts for the engine's primary tables, as returned by `ClawEngine::db_stats`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbStats {
    /// Number of rows in the `memories` table.
    pub memory_count: u64,
    /// Number of rows in the `sessions` table.
    pub session_count: u64,
    /// Number of rows in the `tool_output` table.
    pub tool_output_count: u64,
}
35
36#[derive(Debug, Clone)]
38pub struct ClawStats {
39 pub total_memories: u64,
41 pub total_sessions: u64,
43 pub cache_hit_rate: f64,
45 pub cache_size: usize,
47 pub db_size_bytes: u64,
49 pub wal_size_bytes: u64,
51 pub last_snapshot_at: Option<DateTime<Utc>>,
53}
54
/// SQLite-backed store for memories, sessions and tool outputs, fronted by an in-process
/// cache of memory records.
#[derive(Debug)]
pub struct ClawEngine {
    /// Configuration the engine was opened with.
    pub(crate) config: ClawConfig,
    /// Connection pool for `config.db_path`; replaced with a fresh pool by `restore`.
    pub(crate) pool: SqlitePool,
    /// Memory records keyed by id: populated on insert and on cache-miss reads,
    /// invalidated on update/delete.
    pub(crate) cache: Arc<Mutex<ClawCache<Uuid, MemoryRecord>>>,
    /// Hit/miss/insert counters reported by `cache_stats`.
    stats: Arc<Mutex<CacheStats>>,
    /// Time of the most recent successful `snapshot` in this engine instance.
    last_snapshot_at: Arc<Mutex<Option<DateTime<Utc>>>>,
    /// Lifetime count of cache hits in `get_memory`.
    cache_hits: AtomicU64,
    /// Lifetime count of cache misses in `get_memory`.
    cache_misses: AtomicU64,
    /// Hit (`true`) / miss (`false`) outcome of the most recent reads, capped at 1000
    /// entries; `stats` derives `cache_hit_rate` from it.
    read_window: Arc<Mutex<VecDeque<bool>>>,
}
75
76impl ClawEngine {
77 #[tracing::instrument(skip(config), fields(workspace_id = %config.workspace_id))]
79 pub async fn open(config: ClawConfig) -> ClawResult<Self> {
80 let pool = Self::connect_pool(&config, true).await?;
81
82 #[cfg(feature = "encryption")]
85 if let Some(key) = &config.encryption_key {
86 Self::apply_pragmas_key(&pool, key).await?;
87 }
88
89 let cache_cap = ((config.cache_size_mb * 1024 * 1024) / 512).max(64);
90 let cache = Arc::new(Mutex::new(ClawCache::new(cache_cap)?));
91
92 let engine = ClawEngine {
93 config,
94 pool,
95 cache,
96 stats: Arc::new(Mutex::new(CacheStats::new())),
97 last_snapshot_at: Arc::new(Mutex::new(None)),
98 cache_hits: AtomicU64::new(0),
99 cache_misses: AtomicU64::new(0),
100 read_window: Arc::new(Mutex::new(VecDeque::with_capacity(1000))),
101 };
102
103 if engine.config.auto_migrate {
104 engine.migrate().await?;
105 }
106
107 Ok(engine)
108 }
109
110 #[tracing::instrument(fields(workspace_id = "default"))]
112 pub async fn open_default() -> ClawResult<Self> {
113 ClawEngine::open(ClawConfig::default()).await
114 }
115
116 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
118 pub async fn migrate(&self) -> ClawResult<()> {
119 crate::schema::migrations::run_migrations(&self.pool).await
120 }
121
122 pub fn pool(&self) -> &SqlitePool {
124 &self.pool
125 }
126
127 pub fn config(&self) -> &ClawConfig {
129 &self.config
130 }
131
132 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
134 pub async fn close(self) {
135 self.pool.close().await;
136 }
137
138 #[tracing::instrument(skip(self, record), fields(workspace_id = %self.config.workspace_id, memory_id = %record.id))]
140 pub async fn insert_memory(&self, record: &MemoryRecord) -> ClawResult<Uuid> {
141 MemoryStore::new(&self.pool).insert(record).await?;
142 let mut cache = self.cache.lock().await;
143 let mut stats = self.stats.lock().await;
144 cache.insert(record.id, record.clone());
145 stats.insert_count += 1;
146 Ok(record.id)
147 }
148
149 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id, memory_id = %id))]
151 pub async fn get_memory(&self, id: Uuid) -> ClawResult<MemoryRecord> {
152 {
153 let mut cache = self.cache.lock().await;
154 if let Some(record) = cache.get(&id) {
155 self.cache_hits.fetch_add(1, Ordering::Relaxed);
156 self.push_read_window(true).await;
157 let mut stats = self.stats.lock().await;
158 stats.record_hit();
159 return Ok(record.clone());
160 }
161 }
162
163 self.cache_misses.fetch_add(1, Ordering::Relaxed);
164 self.push_read_window(false).await;
165 {
166 let mut stats = self.stats.lock().await;
167 stats.record_miss();
168 }
169
170 let record = MemoryStore::new(&self.pool).get(id).await?;
171 let mut cache = self.cache.lock().await;
172 cache.insert(record.id, record.clone());
173 Ok(record)
174 }
175
176 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id, memory_id = %id))]
178 pub async fn update_memory(&self, id: Uuid, content: &str) -> ClawResult<()> {
179 let updated_at = Utc::now();
180 MemoryStore::new(&self.pool)
181 .update_content(id, content, updated_at)
182 .await?;
183 self.cache.lock().await.invalidate(&id);
184 Ok(())
185 }
186
187 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id, memory_id = %id))]
189 pub async fn delete_memory(&self, id: Uuid) -> ClawResult<()> {
190 MemoryStore::new(&self.pool).delete(id).await?;
191 self.cache.lock().await.invalidate(&id);
192 Ok(())
193 }
194
195 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
197 pub async fn list_memories(
198 &self,
199 type_filter: Option<MemoryType>,
200 ) -> ClawResult<Vec<MemoryRecord>> {
201 MemoryStore::new(&self.pool)
202 .list(type_filter.as_ref())
203 .await
204 }
205
206 #[tracing::instrument(skip(self, opts), fields(workspace_id = %self.config.workspace_id))]
208 pub async fn list_memories_paginated(
209 &self,
210 type_filter: Option<MemoryType>,
211 opts: ListOptions,
212 ) -> ClawResult<ListPage<MemoryRecord>> {
213 MemoryStore::new(&self.pool)
214 .list_paginated(type_filter.as_ref(), &opts)
215 .await
216 }
217
218 #[tracing::instrument(skip(self, opts), fields(workspace_id = %self.config.workspace_id))]
220 pub async fn get_memories_by_type(
221 &self,
222 memory_type: MemoryType,
223 opts: Option<ListOptions>,
224 ) -> ClawResult<ListPage<MemoryRecord>> {
225 let options = opts.unwrap_or_default();
226 MemoryStore::new(&self.pool)
227 .list_paginated(Some(&memory_type), &options)
228 .await
229 }
230
231 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
233 pub async fn search_by_tag(&self, tag: &str) -> ClawResult<Vec<MemoryRecord>> {
234 MemoryStore::new(&self.pool).search_by_tag(tag, 50, 0).await
235 }
236
237 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
239 pub async fn search_by_tag_paginated(
240 &self,
241 tag: &str,
242 limit: u32,
243 offset: u32,
244 ) -> ClawResult<Vec<MemoryRecord>> {
245 let bounded_limit = limit.clamp(1, 1000);
246 MemoryStore::new(&self.pool)
247 .search_by_tag(tag, bounded_limit, offset)
248 .await
249 }
250
251 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
253 pub async fn fts_search(&self, query: &str) -> ClawResult<Vec<MemoryRecord>> {
254 MemoryStore::new(&self.pool).fts_search(query).await
255 }
256
257 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
259 pub async fn expire_ttl_memories(&self) -> ClawResult<u64> {
260 let deleted = MemoryStore::new(&self.pool).expire_ttl().await?;
261 if deleted > 0 {
262 self.cache.lock().await.clear();
263 }
264 Ok(deleted)
265 }
266
267 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
269 pub async fn start_session(&self) -> ClawResult<String> {
270 SessionLifecycleStore::new(&self.pool).start().await
271 }
272
273 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
275 pub async fn end_session(&self, session_id: &str) -> ClawResult<()> {
276 SessionLifecycleStore::new(&self.pool).end(session_id).await
277 }
278
279 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
281 pub async fn get_session(&self, session_id: &str) -> ClawResult<Session> {
282 SessionLifecycleStore::new(&self.pool).get(session_id).await
283 }
284
285 #[tracing::instrument(skip(self, opts), fields(workspace_id = %self.config.workspace_id))]
287 pub async fn list_sessions(&self, opts: Option<ListOptions>) -> ClawResult<ListPage<Session>> {
288 let options = opts.unwrap_or_default();
289 SessionLifecycleStore::new(&self.pool)
290 .list_paginated(&options)
291 .await
292 }
293
294 #[tracing::instrument(skip(self, output), fields(workspace_id = %self.config.workspace_id))]
296 pub async fn record_tool_output(&self, output: &ToolOutputRecord) -> ClawResult<()> {
297 ToolOutputStore::new(&self.pool).insert(output).await
298 }
299
300 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
302 pub async fn list_tool_outputs(&self, session_id: &str) -> ClawResult<Vec<ToolOutputRecord>> {
303 ToolOutputStore::new(&self.pool)
304 .get_by_session(session_id)
305 .await
306 }
307
308 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
310 pub async fn transaction(&self) -> ClawResult<crate::transaction::ClawTransaction<'_>> {
311 crate::transaction::ClawTransaction::begin(self).await
312 }
313
314 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
316 pub async fn begin_transaction(&self) -> ClawResult<crate::transaction::ClawTransaction<'_>> {
317 crate::transaction::ClawTransaction::begin(self).await
318 }
319
320 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
322 pub async fn snapshot(&self) -> ClawResult<PathBuf> {
323 let snapshot_dir = self
324 .config
325 .snapshot_dir
326 .as_ref()
327 .ok_or_else(|| ClawError::Config("snapshot_dir must be set".to_string()))?;
328
329 std::fs::create_dir_all(snapshot_dir)?;
330
331 sqlx::query("PRAGMA wal_checkpoint(FULL)")
333 .execute(&self.pool)
334 .await?;
335
336 let created_at_ms = Utc::now().timestamp_millis() as u64;
337 let final_path = snapshot_dir.join(format!("{created_at_ms}.db"));
338 let tmp_path = PathBuf::from(format!("{}.tmp", final_path.display()));
339
340 std::fs::copy(&self.config.db_path, &tmp_path).map_err(|e| {
341 ClawError::Snapshot(format!(
342 "failed to copy '{}' to '{}': {e}",
343 self.config.db_path.display(),
344 tmp_path.display()
345 ))
346 })?;
347
348 std::fs::rename(&tmp_path, &final_path).map_err(|e| {
349 ClawError::Snapshot(format!(
350 "failed to rename '{}' to '{}': {e}",
351 tmp_path.display(),
352 final_path.display()
353 ))
354 })?;
355
356 let size_bytes = std::fs::metadata(&final_path)
357 .map_err(|e| ClawError::Snapshot(format!("failed to stat snapshot file: {e}")))?
358 .len();
359
360 let blake3 = blake3_file_hex(&final_path)?;
361 let manifest = SnapshotManifest {
362 version: 1,
363 created_at_ms,
364 source_db: self.config.db_path.display().to_string(),
365 size_bytes,
366 blake3,
367 };
368
369 let manifest_path = manifest_path_for(&final_path);
370 let manifest_bytes = serde_json::to_vec_pretty(&manifest)
371 .map_err(|e| ClawError::Snapshot(format!("failed to serialize manifest: {e}")))?;
372 std::fs::write(&manifest_path, manifest_bytes).map_err(|e| {
373 ClawError::Snapshot(format!(
374 "failed to write manifest '{}': {e}",
375 manifest_path.display()
376 ))
377 })?;
378
379 *self.last_snapshot_at.lock().await = Some(Utc::now());
380 Ok(final_path)
381 }
382
383 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
385 pub async fn snapshot_create(&self) -> ClawResult<SnapshotMeta> {
386 let path = self.snapshot().await?;
387 let created_at_ms = path
388 .file_stem()
389 .and_then(|s| s.to_str())
390 .and_then(|s| s.parse::<u64>().ok())
391 .ok_or_else(|| {
392 ClawError::Snapshot("snapshot filename is not a unix-ms timestamp".to_string())
393 })?;
394 let created_at = Utc
395 .timestamp_millis_opt(created_at_ms as i64)
396 .single()
397 .ok_or_else(|| ClawError::Snapshot("invalid snapshot timestamp".to_string()))?;
398 let size_bytes = std::fs::metadata(&path)?.len();
399 let checksum = blake3_file_hex(&path)?;
400 Ok(SnapshotMeta {
401 path,
402 created_at,
403 size_bytes,
404 checksum,
405 })
406 }
407
408 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id, snapshot = %snapshot_path.display()))]
410 pub async fn restore(&mut self, snapshot_path: &Path) -> ClawResult<()> {
411 verify_snapshot_integrity(snapshot_path)?;
412
413 self.pool.close().await;
414
415 let wal_path = PathBuf::from(format!("{}-wal", self.config.db_path.display()));
416 if wal_path.exists() {
417 std::fs::remove_file(&wal_path)?;
418 }
419 let shm_path = PathBuf::from(format!("{}-shm", self.config.db_path.display()));
420 if shm_path.exists() {
421 std::fs::remove_file(&shm_path)?;
422 }
423
424 std::fs::copy(snapshot_path, &self.config.db_path).map_err(|e| {
425 ClawError::Snapshot(format!(
426 "failed to restore snapshot '{}' into '{}': {e}",
427 snapshot_path.display(),
428 self.config.db_path.display()
429 ))
430 })?;
431
432 self.pool = Self::connect_pool(&self.config, false).await?;
433
434 #[cfg(feature = "encryption")]
435 if let Some(key) = &self.config.encryption_key {
436 Self::apply_pragmas_key(&self.pool, key).await?;
437 }
438
439 self.migrate().await?;
440 self.cache.lock().await.clear();
441
442 Ok(())
443 }
444
445 pub fn list_snapshots(&self) -> ClawResult<Vec<SnapshotManifest>> {
447 let snapshot_dir = self
448 .config
449 .snapshot_dir
450 .as_ref()
451 .ok_or_else(|| ClawError::Config("snapshot_dir must be set".to_string()))?;
452
453 let mut manifests = Vec::new();
454 for entry in std::fs::read_dir(snapshot_dir)? {
455 let path = entry?.path();
456 if path
457 .file_name()
458 .and_then(|n| n.to_str())
459 .map(|n| n.ends_with(".manifest.json"))
460 .unwrap_or(false)
461 {
462 let bytes = std::fs::read(&path)?;
463 let manifest: SnapshotManifest = serde_json::from_slice(&bytes).map_err(|e| {
464 ClawError::Snapshot(format!("cannot parse manifest '{}': {e}", path.display()))
465 })?;
466 manifests.push(manifest);
467 }
468 }
469
470 manifests.sort_by(|a, b| b.created_at_ms.cmp(&a.created_at_ms));
471 Ok(manifests)
472 }
473
474 pub fn delete_snapshot(&self, path: &Path) -> ClawResult<()> {
476 if path.exists() {
477 std::fs::remove_file(path)?;
478 }
479 let manifest_path = manifest_path_for(path);
480 if manifest_path.exists() {
481 std::fs::remove_file(manifest_path)?;
482 }
483 Ok(())
484 }
485
486 #[cfg(feature = "encryption")]
488 #[tracing::instrument(skip(self, old_key, new_key), fields(workspace_id = %self.config.workspace_id))]
489 pub async fn rotate_key(&self, old_key: [u8; 32], new_key: [u8; 32]) -> ClawResult<()> {
490 Self::apply_pragmas_key(&self.pool, &old_key).await?;
491 let new_hex: String = new_key.iter().map(|b| format!("{b:02x}")).collect();
492 sqlx::query(&format!("PRAGMA rekey = \"x'{new_hex}'\""))
493 .execute(&self.pool)
494 .await?;
495 Ok(())
496 }
497
498 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
500 pub async fn cache_stats(&self) -> CacheStats {
501 self.stats.lock().await.clone()
502 }
503
504 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
506 pub async fn stats(&self) -> ClawResult<ClawStats> {
507 let (total_memories,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM memories")
508 .fetch_one(&self.pool)
509 .await?;
510 let (total_sessions,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM sessions")
511 .fetch_one(&self.pool)
512 .await?;
513
514 let _lifetime_hits = self.cache_hits.load(Ordering::Relaxed);
515 let _lifetime_misses = self.cache_misses.load(Ordering::Relaxed);
516
517 let cache_size = self.cache.lock().await.len();
518 let cache_hit_rate = {
519 let window = self.read_window.lock().await;
520 if window.is_empty() {
521 0.0
522 } else {
523 let hits = window.iter().filter(|&&v| v).count();
524 hits as f64 / window.len() as f64
525 }
526 };
527
528 let db_size_bytes = std::fs::metadata(&self.config.db_path)
529 .map(|m| m.len())
530 .unwrap_or(0);
531 let wal_path = PathBuf::from(format!("{}-wal", self.config.db_path.display()));
532 let wal_size_bytes = std::fs::metadata(wal_path).map(|m| m.len()).unwrap_or(0);
533 let last_snapshot_at = *self.last_snapshot_at.lock().await;
534
535 Ok(ClawStats {
536 total_memories: total_memories as u64,
537 total_sessions: total_sessions as u64,
538 cache_hit_rate,
539 cache_size,
540 db_size_bytes,
541 wal_size_bytes,
542 last_snapshot_at,
543 })
544 }
545
546 #[tracing::instrument(skip(self), fields(workspace_id = %self.config.workspace_id))]
548 pub async fn db_stats(&self) -> ClawResult<DbStats> {
549 let (mc,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM memories")
550 .fetch_one(&self.pool)
551 .await?;
552 let (sc,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM sessions")
553 .fetch_one(&self.pool)
554 .await?;
555 let (tc,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM tool_output")
556 .fetch_one(&self.pool)
557 .await?;
558 Ok(DbStats {
559 memory_count: mc as u64,
560 session_count: sc as u64,
561 tool_output_count: tc as u64,
562 })
563 }
564
565 async fn connect_pool(config: &ClawConfig, create_if_missing: bool) -> ClawResult<SqlitePool> {
566 let db_url = format!("sqlite:{}", config.db_path.display());
567 let journal_mode = match config.journal_mode {
568 crate::config::JournalMode::WAL => SqliteJournalMode::Wal,
569 crate::config::JournalMode::Delete => SqliteJournalMode::Delete,
570 crate::config::JournalMode::Truncate => SqliteJournalMode::Truncate,
571 };
572
573 let connect_options = SqliteConnectOptions::from_str(&db_url)
574 .map_err(|e| ClawError::Config(format!("invalid database URL: {e}")))?
575 .create_if_missing(create_if_missing)
576 .journal_mode(journal_mode);
577
578 let pool = SqlitePoolOptions::new()
579 .max_connections(config.max_connections)
580 .connect_with(connect_options)
581 .await?;
582
583 Ok(pool)
584 }
585
586 async fn push_read_window(&self, hit: bool) {
587 let mut window = self.read_window.lock().await;
588 if window.len() >= 1000 {
589 window.pop_front();
590 }
591 window.push_back(hit);
592 }
593
594 #[cfg(feature = "encryption")]
595 async fn apply_pragmas_key(pool: &SqlitePool, key: &[u8; 32]) -> ClawResult<()> {
596 let hex: String = key.iter().map(|b| format!("{b:02x}")).collect();
597 sqlx::query(&format!("PRAGMA key = \"x'{hex}'\""))
598 .execute(pool)
599 .await?;
600 Ok(())
601 }
602}