1use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use sqlx::SqlitePool;
10use uuid::Uuid;
11
12use crate::error::{ClawError, ClawResult};
13
/// Paging parameters for cursor-based listings such as `MemoryStore::list_paginated`.
#[derive(Debug, Clone)]
pub struct ListOptions {
    /// Requested page size. Consumers read it through
    /// [`ListOptions::validated_limit`], which forces it into `1..=1000`.
    pub limit: u32,
    /// The `next_cursor` of a previous page, or `None` to start at the first page.
    pub cursor: Option<String>,
}

impl Default for ListOptions {
    /// First page with 50 items per page.
    fn default() -> Self {
        ListOptions {
            cursor: None,
            limit: Self::DEFAULT_LIMIT,
        }
    }
}

impl ListOptions {
    /// Page size used by the [`Default`] implementation.
    const DEFAULT_LIMIT: u32 = 50;
    /// Smallest page size `validated_limit` will return.
    const MIN_LIMIT: u32 = 1;
    /// Largest page size `validated_limit` will return.
    const MAX_LIMIT: u32 = 1000;

    /// Returns `limit` forced into `1..=1000`, so a zero or oversized request
    /// still produces a usable page size.
    pub fn validated_limit(&self) -> u32 {
        match self.limit {
            n if n < Self::MIN_LIMIT => Self::MIN_LIMIT,
            n if n > Self::MAX_LIMIT => Self::MAX_LIMIT,
            n => n,
        }
    }
}
40
/// One page of results from a cursor-paginated listing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListPage<T> {
    /// Items on this page, in the listing's sort order.
    pub items: Vec<T>,
    /// Value to pass as `ListOptions::cursor` to fetch the following page;
    /// `None` when no further rows were found.
    pub next_cursor: Option<String>,
}
49
/// Category of a stored memory.
///
/// Persisted in the `memories.memory_type` column as the lowercase string from
/// [`MemoryType::as_str`] and parsed back through `FromStr`.
///
/// NOTE(review): the derived serde form follows serde's defaults (variant names
/// such as `"Semantic"`) unless an attribute not visible here renames it, whereas
/// `as_str`/`FromStr` use lowercase — confirm which form API consumers expect.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryType {
    Semantic,
    Episodic,
    Working,
    Procedural,
}
62
63impl MemoryType {
64 pub fn as_str(&self) -> &'static str {
66 match self {
67 MemoryType::Semantic => "semantic",
68 MemoryType::Episodic => "episodic",
69 MemoryType::Working => "working",
70 MemoryType::Procedural => "procedural",
71 }
72 }
73}
74
75impl std::fmt::Display for MemoryType {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 f.write_str(self.as_str())
78 }
79}
80
81impl std::str::FromStr for MemoryType {
82 type Err = ClawError;
83
84 fn from_str(s: &str) -> Result<Self, Self::Err> {
85 match s {
86 "semantic" => Ok(MemoryType::Semantic),
87 "episodic" => Ok(MemoryType::Episodic),
88 "working" => Ok(MemoryType::Working),
89 "procedural" => Ok(MemoryType::Procedural),
90 other => Err(ClawError::InvalidInput(format!(
91 "unknown memory type: {other}"
92 ))),
93 }
94 }
95}
96
/// A single memory, as persisted by `MemoryStore`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryRecord {
    /// Primary key; `MemoryRecord::new` assigns a random v4 UUID.
    pub id: Uuid,
    /// Free text; `MemoryStore` also indexes it in `memories_fts` for full-text search.
    pub content: String,
    /// Category of the memory.
    pub memory_type: MemoryType,
    /// Labels; persisted both as a JSON array and as one `memory_tags` row per tag.
    pub tags: Vec<String>,
    /// Lifetime in seconds counted from `created_at`. `MemoryStore::expire_ttl`
    /// deletes the record once this has elapsed; `None` means it is never expired.
    pub ttl_seconds: Option<u64>,
    /// Creation time (UTC), stored as an RFC 3339 string.
    pub created_at: DateTime<Utc>,
    /// Last modification time (UTC); `MemoryStore::update_content` overwrites it.
    pub updated_at: DateTime<Utc>,
}
115
116impl MemoryRecord {
117 pub fn new(
119 content: impl Into<String>,
120 memory_type: MemoryType,
121 tags: Vec<String>,
122 ttl_seconds: Option<u64>,
123 ) -> Self {
124 let now = Utc::now();
125 MemoryRecord {
126 id: Uuid::new_v4(),
127 content: content.into(),
128 memory_type,
129 tags,
130 ttl_seconds,
131 created_at: now,
132 updated_at: now,
133 }
134 }
135}
136
/// Data-access layer for `MemoryRecord`s over a borrowed SQLite pool.
///
/// Keeps three tables in step: `memories` (one row per record), `memories_fts`
/// (full-text index sharing the `memories` rowid) and `memory_tags`
/// (one `(memory_id, tag)` row per tag).
#[derive(Debug)]
pub struct MemoryStore<'a> {
    /// Pool every query runs against; only borrowed, so creating a store is cheap.
    pool: &'a SqlitePool,
}
142
143impl<'a> MemoryStore<'a> {
144 pub fn new(pool: &'a SqlitePool) -> Self {
146 MemoryStore { pool }
147 }
148
149 pub async fn insert(&self, record: &MemoryRecord) -> ClawResult<()> {
153 let tags = serde_json::to_string(&record.tags)?;
154 let mut tx = self.pool.begin().await?;
155
156 sqlx::query(
157 "INSERT INTO memories \
158 (id, content, memory_type, tags, ttl_seconds, created_at, updated_at) \
159 VALUES (?, ?, ?, ?, ?, ?, ?)",
160 )
161 .bind(record.id.to_string())
162 .bind(&record.content)
163 .bind(record.memory_type.as_str())
164 .bind(&tags)
165 .bind(record.ttl_seconds.map(|s| s as i64))
166 .bind(record.created_at.to_rfc3339())
167 .bind(record.updated_at.to_rfc3339())
168 .execute(&mut *tx)
169 .await?;
170
171 sqlx::query(
172 "INSERT INTO memories_fts(rowid, content) \
173 VALUES (last_insert_rowid(), ?)",
174 )
175 .bind(&record.content)
176 .execute(&mut *tx)
177 .await?;
178
179 for tag in &record.tags {
180 sqlx::query("INSERT OR REPLACE INTO memory_tags(memory_id, tag) VALUES (?, ?)")
181 .bind(record.id.to_string())
182 .bind(tag)
183 .execute(&mut *tx)
184 .await?;
185 }
186
187 tx.commit().await?;
188 Ok(())
189 }
190
191 pub async fn get(&self, id: Uuid) -> ClawResult<MemoryRecord> {
193 let row =
194 sqlx::query_as::<_, (String, String, String, String, Option<i64>, String, String)>(
195 "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
196 FROM memories WHERE id = ?",
197 )
198 .bind(id.to_string())
199 .fetch_optional(self.pool)
200 .await?;
201
202 let row = row.ok_or_else(|| ClawError::NotFound {
203 entity: "MemoryRecord".to_string(),
204 id: id.to_string(),
205 })?;
206
207 Self::row_to_record(row)
208 }
209
210 pub async fn update_content(
214 &self,
215 id: Uuid,
216 content: &str,
217 updated_at: DateTime<Utc>,
218 ) -> ClawResult<()> {
219 let mut tx = self.pool.begin().await?;
220
221 let affected = sqlx::query("UPDATE memories SET content = ?, updated_at = ? WHERE id = ?")
222 .bind(content)
223 .bind(updated_at.to_rfc3339())
224 .bind(id.to_string())
225 .execute(&mut *tx)
226 .await?
227 .rows_affected();
228
229 if affected == 0 {
230 return Err(ClawError::NotFound {
231 entity: "MemoryRecord".to_string(),
232 id: id.to_string(),
233 });
234 }
235
236 sqlx::query(
237 "DELETE FROM memories_fts WHERE rowid = (SELECT rowid FROM memories WHERE id = ?)",
238 )
239 .bind(id.to_string())
240 .execute(&mut *tx)
241 .await?;
242
243 sqlx::query(
244 "INSERT INTO memories_fts(rowid, content) \
245 VALUES ((SELECT rowid FROM memories WHERE id = ?), ?)",
246 )
247 .bind(id.to_string())
248 .bind(content)
249 .execute(&mut *tx)
250 .await?;
251
252 tx.commit().await?;
253 Ok(())
254 }
255
256 pub async fn delete(&self, id: Uuid) -> ClawResult<()> {
260 let mut tx = self.pool.begin().await?;
261
262 sqlx::query(
263 "DELETE FROM memories_fts WHERE rowid = (SELECT rowid FROM memories WHERE id = ?)",
264 )
265 .bind(id.to_string())
266 .execute(&mut *tx)
267 .await?;
268
269 let affected = sqlx::query("DELETE FROM memories WHERE id = ?")
270 .bind(id.to_string())
271 .execute(&mut *tx)
272 .await?
273 .rows_affected();
274
275 if affected == 0 {
276 return Err(ClawError::NotFound {
277 entity: "MemoryRecord".to_string(),
278 id: id.to_string(),
279 });
280 }
281
282 tx.commit().await?;
283 Ok(())
284 }
285
286 pub async fn list(&self, type_filter: Option<&MemoryType>) -> ClawResult<Vec<MemoryRecord>> {
288 #[allow(clippy::type_complexity)]
289 let rows: Vec<(String, String, String, String, Option<i64>, String, String)> =
290 match type_filter {
291 Some(mt) => sqlx::query_as(
292 "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
293 FROM memories WHERE memory_type = ? ORDER BY created_at DESC, id DESC",
294 )
295 .bind(mt.as_str())
296 .fetch_all(self.pool)
297 .await?,
298 None => sqlx::query_as(
299 "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
300 FROM memories ORDER BY created_at DESC, id DESC",
301 )
302 .fetch_all(self.pool)
303 .await?,
304 };
305
306 rows.into_iter().map(Self::row_to_record).collect()
307 }
308
309 pub async fn search_by_tag(
311 &self,
312 tag: &str,
313 limit: u32,
314 offset: u32,
315 ) -> ClawResult<Vec<MemoryRecord>> {
316 #[allow(clippy::type_complexity)]
317 let rows: Vec<(String, String, String, String, Option<i64>, String, String)> =
318 sqlx::query_as(
319 "SELECT m.id, m.content, m.memory_type, m.tags, m.ttl_seconds, \
320 m.created_at, m.updated_at \
321 FROM memories m \
322 JOIN memory_tags t ON m.id = t.memory_id \
323 WHERE t.tag = ? \
324 ORDER BY m.created_at DESC LIMIT ? OFFSET ?",
325 )
326 .bind(tag)
327 .bind(limit as i64)
328 .bind(offset as i64)
329 .fetch_all(self.pool)
330 .await?;
331
332 rows.into_iter().map(Self::row_to_record).collect()
333 }
334
335 pub async fn list_paginated(
337 &self,
338 type_filter: Option<&MemoryType>,
339 opts: &ListOptions,
340 ) -> ClawResult<ListPage<MemoryRecord>> {
341 let limit = opts.validated_limit() as i64;
342 let fetch = limit.saturating_add(1);
343
344 #[allow(clippy::type_complexity)]
345 let rows: Vec<(String, String, String, String, Option<i64>, String, String)> =
346 match (&opts.cursor, type_filter) {
347 (None, None) => sqlx::query_as(
348 "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
349 FROM memories ORDER BY created_at DESC, id DESC LIMIT ?",
350 )
351 .bind(fetch)
352 .fetch_all(self.pool)
353 .await?,
354 (None, Some(mt)) => sqlx::query_as(
355 "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
356 FROM memories WHERE memory_type = ? \
357 ORDER BY created_at DESC, id DESC LIMIT ?",
358 )
359 .bind(mt.as_str())
360 .bind(fetch)
361 .fetch_all(self.pool)
362 .await?,
363 (Some(cursor), None) => sqlx::query_as(
364 "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
365 FROM memories \
366 WHERE (created_at, id) < \
367 (SELECT created_at, id FROM memories WHERE id = ?) \
368 ORDER BY created_at DESC, id DESC LIMIT ?",
369 )
370 .bind(cursor)
371 .bind(fetch)
372 .fetch_all(self.pool)
373 .await?,
374 (Some(cursor), Some(mt)) => sqlx::query_as(
375 "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
376 FROM memories \
377 WHERE memory_type = ? \
378 AND (created_at, id) < \
379 (SELECT created_at, id FROM memories WHERE id = ?) \
380 ORDER BY created_at DESC, id DESC LIMIT ?",
381 )
382 .bind(mt.as_str())
383 .bind(cursor)
384 .bind(fetch)
385 .fetch_all(self.pool)
386 .await?,
387 };
388
389 let has_more = rows.len() as i64 > limit;
390 let page_rows = if has_more {
391 &rows[..limit as usize]
392 } else {
393 rows.as_slice()
394 };
395
396 let items = page_rows
397 .iter()
398 .cloned()
399 .map(Self::row_to_record)
400 .collect::<ClawResult<Vec<_>>>()?;
401
402 let next_cursor = if has_more {
403 items.last().map(|item| item.id.to_string())
404 } else {
405 None
406 };
407
408 Ok(ListPage { items, next_cursor })
409 }
410
411 pub async fn fts_search(&self, query: &str) -> ClawResult<Vec<MemoryRecord>> {
413 #[allow(clippy::type_complexity)]
414 let rows: Vec<(String, String, String, String, Option<i64>, String, String)> =
415 sqlx::query_as(
416 "SELECT m.id, m.content, m.memory_type, m.tags, m.ttl_seconds, m.created_at, m.updated_at \
417 FROM memories_fts f \
418 JOIN memories m ON m.rowid = f.rowid \
419 WHERE memories_fts MATCH ? \
420 ORDER BY m.created_at DESC, m.id DESC",
421 )
422 .bind(query)
423 .fetch_all(self.pool)
424 .await?;
425
426 rows.into_iter().map(Self::row_to_record).collect()
427 }
428
429 pub async fn expire_ttl(&self) -> ClawResult<u64> {
431 let rows: Vec<(String, Option<i64>, String)> = sqlx::query_as(
432 "SELECT id, ttl_seconds, created_at FROM memories WHERE ttl_seconds IS NOT NULL",
433 )
434 .fetch_all(self.pool)
435 .await?;
436
437 let now = Utc::now();
438 let mut deleted = 0u64;
439
440 for (id_str, ttl_secs, created_at_str) in rows {
441 if let Some(ttl) = ttl_secs {
442 let created_at = DateTime::parse_from_rfc3339(&created_at_str)
443 .map_err(|e| ClawError::Store(e.to_string()))?
444 .with_timezone(&Utc);
445 let expiry = created_at + chrono::Duration::seconds(ttl);
446 if now >= expiry {
447 sqlx::query("DELETE FROM memories_fts WHERE rowid = (SELECT rowid FROM memories WHERE id = ?)")
448 .bind(&id_str)
449 .execute(self.pool)
450 .await?;
451 sqlx::query("DELETE FROM memories WHERE id = ?")
452 .bind(&id_str)
453 .execute(self.pool)
454 .await?;
455 deleted += 1;
456 }
457 }
458 }
459
460 Ok(deleted)
461 }
462
463 fn row_to_record(
464 row: (String, String, String, String, Option<i64>, String, String),
465 ) -> ClawResult<MemoryRecord> {
466 let (id_str, content, memory_type_str, tags_str, ttl_secs, created_at_str, updated_at_str) =
467 row;
468 Ok(MemoryRecord {
469 id: Uuid::parse_str(&id_str).map_err(|e| ClawError::Store(e.to_string()))?,
470 content,
471 memory_type: memory_type_str.parse()?,
472 tags: serde_json::from_str(&tags_str)?,
473 ttl_seconds: ttl_secs.map(|s| s as u64),
474 created_at: DateTime::parse_from_rfc3339(&created_at_str)
475 .map_err(|e| ClawError::Store(e.to_string()))?
476 .with_timezone(&Utc),
477 updated_at: DateTime::parse_from_rfc3339(&updated_at_str)
478 .map_err(|e| ClawError::Store(e.to_string()))?
479 .with_timezone(&Utc),
480 })
481 }
482}