Skip to main content

claw_core/store/
memory.rs

1//! Memory record store for claw-core.
2//!
3//! The `memories` table is the primary store for persistent AI agent memories.
4//! Records are classified by [`MemoryType`], can carry arbitrary tags for
5//! keyword search, and optionally expire after a configurable TTL.
6
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use sqlx::SqlitePool;
10use uuid::Uuid;
11
12use crate::error::{ClawError, ClawResult};
13
/// Pagination settings shared by the list-style queries.
///
/// The default `limit` is `50`. Whatever value is stored, callers should read
/// it through [`ListOptions::validated_limit`], which keeps it within
/// `1..=1000`.
#[derive(Debug, Clone)]
pub struct ListOptions {
    /// Upper bound on the number of items in one page.
    pub limit: u32,
    /// Cursor handed back by the previous page, or `None` for the first page.
    pub cursor: Option<String>,
}

impl Default for ListOptions {
    /// First page, fifty items.
    fn default() -> Self {
        ListOptions {
            cursor: None,
            limit: 50,
        }
    }
}

impl ListOptions {
    /// Page size to actually use: `limit` forced into the range `1..=1000`.
    pub fn validated_limit(&self) -> u32 {
        const SMALLEST_PAGE: u32 = 1;
        const LARGEST_PAGE: u32 = 1000;

        if self.limit < SMALLEST_PAGE {
            SMALLEST_PAGE
        } else if self.limit > LARGEST_PAGE {
            LARGEST_PAGE
        } else {
            self.limit
        }
    }
}
40
41/// A page of list results with an optional continuation cursor.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct ListPage<T> {
44    /// Page items.
45    pub items: Vec<T>,
46    /// Cursor to request the next page, if any.
47    pub next_cursor: Option<String>,
48}
49
50/// The logical classification of a memory record.
51#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
52pub enum MemoryType {
53    /// Factual world knowledge (e.g. "Paris is the capital of France").
54    Semantic,
55    /// Event-based memories tied to specific experiences.
56    Episodic,
57    /// Short-lived working memory for active reasoning.
58    Working,
59    /// Skill or procedure memory.
60    Procedural,
61}
62
63impl MemoryType {
64    /// Return the string representation stored in the database.
65    pub fn as_str(&self) -> &'static str {
66        match self {
67            MemoryType::Semantic => "semantic",
68            MemoryType::Episodic => "episodic",
69            MemoryType::Working => "working",
70            MemoryType::Procedural => "procedural",
71        }
72    }
73}
74
75impl std::fmt::Display for MemoryType {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        f.write_str(self.as_str())
78    }
79}
80
81impl std::str::FromStr for MemoryType {
82    type Err = ClawError;
83
84    fn from_str(s: &str) -> Result<Self, Self::Err> {
85        match s {
86            "semantic" => Ok(MemoryType::Semantic),
87            "episodic" => Ok(MemoryType::Episodic),
88            "working" => Ok(MemoryType::Working),
89            "procedural" => Ok(MemoryType::Procedural),
90            other => Err(ClawError::InvalidInput(format!(
91                "unknown memory type: {other}"
92            ))),
93        }
94    }
95}
96
97/// A persistent memory record stored in the `memories` table.
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct MemoryRecord {
100    /// Unique record identifier.
101    pub id: Uuid,
102    /// Natural-language content of the memory.
103    pub content: String,
104    /// The logical type of this memory.
105    pub memory_type: MemoryType,
106    /// Searchable tags attached to this record.
107    pub tags: Vec<String>,
108    /// Optional TTL in seconds from `created_at`. `None` means no expiry.
109    pub ttl_seconds: Option<u64>,
110    /// Timestamp when this record was first created.
111    pub created_at: DateTime<Utc>,
112    /// Timestamp when this record was last modified.
113    pub updated_at: DateTime<Utc>,
114}
115
116impl MemoryRecord {
117    /// Create a new [`MemoryRecord`] with a fresh UUID and current UTC timestamps.
118    pub fn new(
119        content: impl Into<String>,
120        memory_type: MemoryType,
121        tags: Vec<String>,
122        ttl_seconds: Option<u64>,
123    ) -> Self {
124        let now = Utc::now();
125        MemoryRecord {
126            id: Uuid::new_v4(),
127            content: content.into(),
128            memory_type,
129            tags,
130            ttl_seconds,
131            created_at: now,
132            updated_at: now,
133        }
134    }
135}
136
137/// Data-access object for the `memories` and `memories_fts` tables.
138#[derive(Debug)]
139pub struct MemoryStore<'a> {
140    pool: &'a SqlitePool,
141}
142
143impl<'a> MemoryStore<'a> {
144    /// Create a new [`MemoryStore`] bound to `pool`.
145    pub fn new(pool: &'a SqlitePool) -> Self {
146        MemoryStore { pool }
147    }
148
149    /// Insert a new [`MemoryRecord`] into the database.
150    ///
151    /// The write is atomic across `memories`, `memories_fts`, and `memory_tags`.
152    pub async fn insert(&self, record: &MemoryRecord) -> ClawResult<()> {
153        let tags = serde_json::to_string(&record.tags)?;
154        let mut tx = self.pool.begin().await?;
155
156        sqlx::query(
157            "INSERT INTO memories \
158             (id, content, memory_type, tags, ttl_seconds, created_at, updated_at) \
159             VALUES (?, ?, ?, ?, ?, ?, ?)",
160        )
161        .bind(record.id.to_string())
162        .bind(&record.content)
163        .bind(record.memory_type.as_str())
164        .bind(&tags)
165        .bind(record.ttl_seconds.map(|s| s as i64))
166        .bind(record.created_at.to_rfc3339())
167        .bind(record.updated_at.to_rfc3339())
168        .execute(&mut *tx)
169        .await?;
170
171        sqlx::query(
172            "INSERT INTO memories_fts(rowid, content) \
173               VALUES (last_insert_rowid(), ?)",
174        )
175        .bind(&record.content)
176        .execute(&mut *tx)
177        .await?;
178
179        for tag in &record.tags {
180            sqlx::query("INSERT OR REPLACE INTO memory_tags(memory_id, tag) VALUES (?, ?)")
181                .bind(record.id.to_string())
182                .bind(tag)
183                .execute(&mut *tx)
184                .await?;
185        }
186
187        tx.commit().await?;
188        Ok(())
189    }
190
191    /// Fetch a [`MemoryRecord`] by its `id`.
192    pub async fn get(&self, id: Uuid) -> ClawResult<MemoryRecord> {
193        let row =
194            sqlx::query_as::<_, (String, String, String, String, Option<i64>, String, String)>(
195                "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
196                 FROM memories WHERE id = ?",
197            )
198            .bind(id.to_string())
199            .fetch_optional(self.pool)
200            .await?;
201
202        let row = row.ok_or_else(|| ClawError::NotFound {
203            entity: "MemoryRecord".to_string(),
204            id: id.to_string(),
205        })?;
206
207        Self::row_to_record(row)
208    }
209
210    /// Update the `content` and `updated_at` of a [`MemoryRecord`].
211    ///
212    /// The write is atomic across `memories` and `memories_fts`.
213    pub async fn update_content(
214        &self,
215        id: Uuid,
216        content: &str,
217        updated_at: DateTime<Utc>,
218    ) -> ClawResult<()> {
219        let mut tx = self.pool.begin().await?;
220
221        let affected = sqlx::query("UPDATE memories SET content = ?, updated_at = ? WHERE id = ?")
222            .bind(content)
223            .bind(updated_at.to_rfc3339())
224            .bind(id.to_string())
225            .execute(&mut *tx)
226            .await?
227            .rows_affected();
228
229        if affected == 0 {
230            return Err(ClawError::NotFound {
231                entity: "MemoryRecord".to_string(),
232                id: id.to_string(),
233            });
234        }
235
236        sqlx::query(
237            "DELETE FROM memories_fts WHERE rowid = (SELECT rowid FROM memories WHERE id = ?)",
238        )
239        .bind(id.to_string())
240        .execute(&mut *tx)
241        .await?;
242
243        sqlx::query(
244            "INSERT INTO memories_fts(rowid, content) \
245             VALUES ((SELECT rowid FROM memories WHERE id = ?), ?)",
246        )
247        .bind(id.to_string())
248        .bind(content)
249        .execute(&mut *tx)
250        .await?;
251
252        tx.commit().await?;
253        Ok(())
254    }
255
256    /// Delete a [`MemoryRecord`] by its `id`.
257    ///
258    /// `memory_tags` rows are removed by `ON DELETE CASCADE`.
259    pub async fn delete(&self, id: Uuid) -> ClawResult<()> {
260        let mut tx = self.pool.begin().await?;
261
262        sqlx::query(
263            "DELETE FROM memories_fts WHERE rowid = (SELECT rowid FROM memories WHERE id = ?)",
264        )
265        .bind(id.to_string())
266        .execute(&mut *tx)
267        .await?;
268
269        let affected = sqlx::query("DELETE FROM memories WHERE id = ?")
270            .bind(id.to_string())
271            .execute(&mut *tx)
272            .await?
273            .rows_affected();
274
275        if affected == 0 {
276            return Err(ClawError::NotFound {
277                entity: "MemoryRecord".to_string(),
278                id: id.to_string(),
279            });
280        }
281
282        tx.commit().await?;
283        Ok(())
284    }
285
286    /// List all [`MemoryRecord`]s, optionally filtered by [`MemoryType`].
287    pub async fn list(&self, type_filter: Option<&MemoryType>) -> ClawResult<Vec<MemoryRecord>> {
288        #[allow(clippy::type_complexity)]
289        let rows: Vec<(String, String, String, String, Option<i64>, String, String)> =
290            match type_filter {
291                Some(mt) => sqlx::query_as(
292                    "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
293                     FROM memories WHERE memory_type = ? ORDER BY created_at DESC, id DESC",
294                )
295                .bind(mt.as_str())
296                .fetch_all(self.pool)
297                .await?,
298                None => sqlx::query_as(
299                    "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
300                     FROM memories ORDER BY created_at DESC, id DESC",
301                )
302                .fetch_all(self.pool)
303                .await?,
304            };
305
306        rows.into_iter().map(Self::row_to_record).collect()
307    }
308
309    /// Search by exact tag using the indexed `memory_tags` table.
310    pub async fn search_by_tag(
311        &self,
312        tag: &str,
313        limit: u32,
314        offset: u32,
315    ) -> ClawResult<Vec<MemoryRecord>> {
316        #[allow(clippy::type_complexity)]
317        let rows: Vec<(String, String, String, String, Option<i64>, String, String)> =
318            sqlx::query_as(
319                "SELECT m.id, m.content, m.memory_type, m.tags, m.ttl_seconds, \
320                        m.created_at, m.updated_at \
321                 FROM memories m \
322                 JOIN memory_tags t ON m.id = t.memory_id \
323                 WHERE t.tag = ? \
324                 ORDER BY m.created_at DESC LIMIT ? OFFSET ?",
325            )
326            .bind(tag)
327            .bind(limit as i64)
328            .bind(offset as i64)
329            .fetch_all(self.pool)
330            .await?;
331
332        rows.into_iter().map(Self::row_to_record).collect()
333    }
334
335    /// List memories with keyset pagination.
336    pub async fn list_paginated(
337        &self,
338        type_filter: Option<&MemoryType>,
339        opts: &ListOptions,
340    ) -> ClawResult<ListPage<MemoryRecord>> {
341        let limit = opts.validated_limit() as i64;
342        let fetch = limit.saturating_add(1);
343
344        #[allow(clippy::type_complexity)]
345        let rows: Vec<(String, String, String, String, Option<i64>, String, String)> =
346            match (&opts.cursor, type_filter) {
347                (None, None) => sqlx::query_as(
348                    "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
349                         FROM memories ORDER BY created_at DESC, id DESC LIMIT ?",
350                )
351                .bind(fetch)
352                .fetch_all(self.pool)
353                .await?,
354                (None, Some(mt)) => sqlx::query_as(
355                    "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
356                         FROM memories WHERE memory_type = ? \
357                         ORDER BY created_at DESC, id DESC LIMIT ?",
358                )
359                .bind(mt.as_str())
360                .bind(fetch)
361                .fetch_all(self.pool)
362                .await?,
363                (Some(cursor), None) => sqlx::query_as(
364                    "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
365                         FROM memories \
366                         WHERE (created_at, id) < \
367                             (SELECT created_at, id FROM memories WHERE id = ?) \
368                         ORDER BY created_at DESC, id DESC LIMIT ?",
369                )
370                .bind(cursor)
371                .bind(fetch)
372                .fetch_all(self.pool)
373                .await?,
374                (Some(cursor), Some(mt)) => sqlx::query_as(
375                    "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
376                         FROM memories \
377                         WHERE memory_type = ? \
378                           AND (created_at, id) < \
379                               (SELECT created_at, id FROM memories WHERE id = ?) \
380                         ORDER BY created_at DESC, id DESC LIMIT ?",
381                )
382                .bind(mt.as_str())
383                .bind(cursor)
384                .bind(fetch)
385                .fetch_all(self.pool)
386                .await?,
387            };
388
389        let has_more = rows.len() as i64 > limit;
390        let page_rows = if has_more {
391            &rows[..limit as usize]
392        } else {
393            rows.as_slice()
394        };
395
396        let items = page_rows
397            .iter()
398            .cloned()
399            .map(Self::row_to_record)
400            .collect::<ClawResult<Vec<_>>>()?;
401
402        let next_cursor = if has_more {
403            items.last().map(|item| item.id.to_string())
404        } else {
405            None
406        };
407
408        Ok(ListPage { items, next_cursor })
409    }
410
411    /// Full-text search over the `memories_fts` index.
412    pub async fn fts_search(&self, query: &str) -> ClawResult<Vec<MemoryRecord>> {
413        #[allow(clippy::type_complexity)]
414        let rows: Vec<(String, String, String, String, Option<i64>, String, String)> =
415            sqlx::query_as(
416                "SELECT m.id, m.content, m.memory_type, m.tags, m.ttl_seconds, m.created_at, m.updated_at \
417                 FROM memories_fts f \
418                 JOIN memories m ON m.rowid = f.rowid \
419                 WHERE memories_fts MATCH ? \
420                 ORDER BY m.created_at DESC, m.id DESC",
421            )
422            .bind(query)
423            .fetch_all(self.pool)
424            .await?;
425
426        rows.into_iter().map(Self::row_to_record).collect()
427    }
428
429    /// Delete all records whose TTL has expired.
430    pub async fn expire_ttl(&self) -> ClawResult<u64> {
431        let rows: Vec<(String, Option<i64>, String)> = sqlx::query_as(
432            "SELECT id, ttl_seconds, created_at FROM memories WHERE ttl_seconds IS NOT NULL",
433        )
434        .fetch_all(self.pool)
435        .await?;
436
437        let now = Utc::now();
438        let mut deleted = 0u64;
439
440        for (id_str, ttl_secs, created_at_str) in rows {
441            if let Some(ttl) = ttl_secs {
442                let created_at = DateTime::parse_from_rfc3339(&created_at_str)
443                    .map_err(|e| ClawError::Store(e.to_string()))?
444                    .with_timezone(&Utc);
445                let expiry = created_at + chrono::Duration::seconds(ttl);
446                if now >= expiry {
447                    sqlx::query("DELETE FROM memories_fts WHERE rowid = (SELECT rowid FROM memories WHERE id = ?)")
448                        .bind(&id_str)
449                        .execute(self.pool)
450                        .await?;
451                    sqlx::query("DELETE FROM memories WHERE id = ?")
452                        .bind(&id_str)
453                        .execute(self.pool)
454                        .await?;
455                    deleted += 1;
456                }
457            }
458        }
459
460        Ok(deleted)
461    }
462
463    fn row_to_record(
464        row: (String, String, String, String, Option<i64>, String, String),
465    ) -> ClawResult<MemoryRecord> {
466        let (id_str, content, memory_type_str, tags_str, ttl_secs, created_at_str, updated_at_str) =
467            row;
468        Ok(MemoryRecord {
469            id: Uuid::parse_str(&id_str).map_err(|e| ClawError::Store(e.to_string()))?,
470            content,
471            memory_type: memory_type_str.parse()?,
472            tags: serde_json::from_str(&tags_str)?,
473            ttl_seconds: ttl_secs.map(|s| s as u64),
474            created_at: DateTime::parse_from_rfc3339(&created_at_str)
475                .map_err(|e| ClawError::Store(e.to_string()))?
476                .with_timezone(&Utc),
477            updated_at: DateTime::parse_from_rfc3339(&updated_at_str)
478                .map_err(|e| ClawError::Store(e.to_string()))?
479                .with_timezone(&Utc),
480        })
481    }
482}