use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::SqlitePool;
use uuid::Uuid;
use crate::error::{ClawError, ClawResult};
/// Paging parameters accepted by `MemoryStore::list_paginated`.
#[derive(Debug, Clone)]
pub struct ListOptions {
/// Requested page size. `validated_limit` clamps it into `1..=1000` before it is used.
pub limit: u32,
/// Position marker from a previous page (`ListPage::next_cursor`); `None` requests the
/// first page. `list_paginated` looks it up as a record id.
pub cursor: Option<String>,
}
impl Default for ListOptions {
fn default() -> Self {
Self {
limit: 50,
cursor: None,
}
}
}
impl ListOptions {
    /// Returns `limit` forced into the inclusive range `1..=1000`.
    ///
    /// A limit of zero becomes 1 and anything above 1000 becomes 1000; values
    /// already inside the range are returned unchanged.
    pub fn validated_limit(&self) -> u32 {
        match self.limit {
            0 => 1,
            too_big if too_big > 1000 => 1000,
            in_range => in_range,
        }
    }
}
/// One page of results returned by `MemoryStore::list_paginated`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListPage<T> {
/// The entries on this page, in query order.
pub items: Vec<T>,
/// Cursor to pass back for the following page. `list_paginated` sets it to the id of the
/// last item when more rows exist and to `None` otherwise.
pub next_cursor: Option<String>,
}
/// Category label attached to every `MemoryRecord`.
///
/// The lowercase names produced by `as_str` are what the store binds into the
/// `memory_type` column and what `FromStr` accepts when reading rows back.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryType {
/// Rendered as `"semantic"`.
Semantic,
/// Rendered as `"episodic"`.
Episodic,
/// Rendered as `"working"`.
Working,
/// Rendered as `"procedural"`.
Procedural,
}
impl MemoryType {
pub fn as_str(&self) -> &'static str {
match self {
MemoryType::Semantic => "semantic",
MemoryType::Episodic => "episodic",
MemoryType::Working => "working",
MemoryType::Procedural => "procedural",
}
}
}
impl std::fmt::Display for MemoryType {
    /// Writes the lowercase variant name from `as_str`, without applying any
    /// width or alignment flags of the formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label: &'static str = self.as_str();
        write!(f, "{label}")
    }
}
impl std::str::FromStr for MemoryType {
    type Err = ClawError;

    /// Parses one of the exact lowercase names `"semantic"`, `"episodic"`,
    /// `"working"` or `"procedural"`.
    ///
    /// # Errors
    ///
    /// Any other input (matching is case-sensitive and does not trim) yields
    /// `ClawError::InvalidInput` carrying the rejected text.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let known = match s {
            "semantic" => Some(Self::Semantic),
            "episodic" => Some(Self::Episodic),
            "working" => Some(Self::Working),
            "procedural" => Some(Self::Procedural),
            _ => None,
        };
        known.ok_or_else(|| {
            let other = s;
            ClawError::InvalidInput(format!("unknown memory type: {other}"))
        })
    }
}
/// A single stored memory entry, as written and read by `MemoryStore`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryRecord {
/// Unique identifier; the store persists it in its string form.
pub id: Uuid,
/// Text body; `MemoryStore::insert` also copies it into the `memories_fts` table.
pub content: String,
/// Category of the record.
pub memory_type: MemoryType,
/// Labels attached to the record. The store serializes them as JSON into the `tags`
/// column and additionally writes one `memory_tags` row per tag.
pub tags: Vec<String>,
/// Optional lifetime in seconds. `MemoryStore::expire_ttl` measures it from `created_at`.
pub ttl_seconds: Option<u64>,
/// Creation time (UTC).
pub created_at: DateTime<Utc>,
/// Time of the last modification (UTC); equal to `created_at` for records built by `new`.
pub updated_at: DateTime<Utc>,
}
impl MemoryRecord {
    /// Builds a fresh record with a random v4 id.
    ///
    /// The current UTC time is sampled once and used for both `created_at`
    /// and `updated_at`. All other fields are taken from the arguments as-is.
    pub fn new(
        content: impl Into<String>,
        memory_type: MemoryType,
        tags: Vec<String>,
        ttl_seconds: Option<u64>,
    ) -> Self {
        let timestamp = Utc::now();
        let id = Uuid::new_v4();
        Self {
            id,
            content: content.into(),
            memory_type,
            tags,
            ttl_seconds,
            created_at: timestamp,
            updated_at: timestamp,
        }
    }
}
/// Data-access wrapper for memory records on top of a borrowed SQLite pool.
///
/// Its methods read and write the `memories`, `memories_fts` and `memory_tags` tables.
#[derive(Debug)]
pub struct MemoryStore<'a> {
/// Connection pool every query runs against; borrowed for the lifetime of the store.
pool: &'a SqlitePool,
}
impl<'a> MemoryStore<'a> {
pub fn new(pool: &'a SqlitePool) -> Self {
MemoryStore { pool }
}
pub async fn insert(&self, record: &MemoryRecord) -> ClawResult<()> {
let tags = serde_json::to_string(&record.tags)?;
let mut tx = self.pool.begin().await?;
sqlx::query(
"INSERT INTO memories \
(id, content, memory_type, tags, ttl_seconds, created_at, updated_at) \
VALUES (?, ?, ?, ?, ?, ?, ?)",
)
.bind(record.id.to_string())
.bind(&record.content)
.bind(record.memory_type.as_str())
.bind(&tags)
.bind(record.ttl_seconds.map(|s| s as i64))
.bind(record.created_at.to_rfc3339())
.bind(record.updated_at.to_rfc3339())
.execute(&mut *tx)
.await?;
sqlx::query(
"INSERT INTO memories_fts(rowid, content) \
VALUES (last_insert_rowid(), ?)",
)
.bind(&record.content)
.execute(&mut *tx)
.await?;
for tag in &record.tags {
sqlx::query("INSERT OR REPLACE INTO memory_tags(memory_id, tag) VALUES (?, ?)")
.bind(record.id.to_string())
.bind(tag)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(())
}
/// Loads the record whose id equals `id`.
///
/// # Errors
///
/// Returns `ClawError::NotFound` when no row matches; database errors and
/// row-decoding errors from `row_to_record` are propagated.
pub async fn get(&self, id: Uuid) -> ClawResult<MemoryRecord> {
    let key = id.to_string();
    #[allow(clippy::type_complexity)]
    let found: Option<(String, String, String, String, Option<i64>, String, String)> =
        sqlx::query_as(
            "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
             FROM memories WHERE id = ?",
        )
        .bind(&key)
        .fetch_optional(self.pool)
        .await?;

    match found {
        Some(row) => Self::row_to_record(row),
        None => Err(ClawError::NotFound {
            entity: "MemoryRecord".to_string(),
            id: key,
        }),
    }
}
/// Replaces the text of record `id` and refreshes its full-text row.
///
/// `updated_at` is written to the `updated_at` column. The `memories` update,
/// the removal of the old `memories_fts` row and the insertion of the new one
/// share one transaction that is committed at the end.
///
/// # Errors
///
/// Returns `ClawError::NotFound` (without committing) when the update touches
/// no row; database errors are propagated.
pub async fn update_content(
    &self,
    id: Uuid,
    content: &str,
    updated_at: DateTime<Utc>,
) -> ClawResult<()> {
    let key = id.to_string();
    let stamp = updated_at.to_rfc3339();
    let mut txn = self.pool.begin().await?;

    let outcome = sqlx::query("UPDATE memories SET content = ?, updated_at = ? WHERE id = ?")
        .bind(content)
        .bind(&stamp)
        .bind(&key)
        .execute(&mut *txn)
        .await?;
    if outcome.rows_affected() == 0 {
        return Err(ClawError::NotFound {
            entity: "MemoryRecord".to_string(),
            id: key,
        });
    }

    // Drop the stale index entry, then index the new text under the same rowid.
    sqlx::query(
        "DELETE FROM memories_fts WHERE rowid = (SELECT rowid FROM memories WHERE id = ?)",
    )
    .bind(&key)
    .execute(&mut *txn)
    .await?;
    sqlx::query(
        "INSERT INTO memories_fts(rowid, content) \
         VALUES ((SELECT rowid FROM memories WHERE id = ?), ?)",
    )
    .bind(&key)
    .bind(content)
    .execute(&mut *txn)
    .await?;

    txn.commit().await?;
    Ok(())
}
/// Removes record `id` along with its full-text row and its tag rows.
///
/// All deletions share one transaction. The `memory_tags` rows written by
/// `insert` are removed explicitly so that deleting a record leaves no
/// orphaned tag rows behind.
///
/// # Errors
///
/// Returns `ClawError::NotFound` (without committing) when no `memories` row
/// matches; database errors are propagated.
pub async fn delete(&self, id: Uuid) -> ClawResult<()> {
    let id_str = id.to_string();
    let mut tx = self.pool.begin().await?;

    // Runs first: the subquery resolves the rowid through `memories`, so the
    // main row must still exist at this point.
    sqlx::query(
        "DELETE FROM memories_fts WHERE rowid = (SELECT rowid FROM memories WHERE id = ?)",
    )
    .bind(&id_str)
    .execute(&mut *tx)
    .await?;

    // Tag rows go before the main row so nothing references it when it is removed.
    sqlx::query("DELETE FROM memory_tags WHERE memory_id = ?")
        .bind(&id_str)
        .execute(&mut *tx)
        .await?;

    let affected = sqlx::query("DELETE FROM memories WHERE id = ?")
        .bind(&id_str)
        .execute(&mut *tx)
        .await?
        .rows_affected();
    if affected == 0 {
        return Err(ClawError::NotFound {
            entity: "MemoryRecord".to_string(),
            id: id_str,
        });
    }

    tx.commit().await?;
    Ok(())
}
/// Returns every record, newest first (`created_at DESC, id DESC`).
///
/// When `type_filter` is given, only records of that memory type are
/// returned. The first row that fails to decode aborts the call with its error.
pub async fn list(&self, type_filter: Option<&MemoryType>) -> ClawResult<Vec<MemoryRecord>> {
    #[allow(clippy::type_complexity)]
    let fetched: Vec<(String, String, String, String, Option<i64>, String, String)> =
        if let Some(kind) = type_filter {
            sqlx::query_as(
                "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
                 FROM memories WHERE memory_type = ? ORDER BY created_at DESC, id DESC",
            )
            .bind(kind.as_str())
            .fetch_all(self.pool)
            .await?
        } else {
            sqlx::query_as(
                "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
                 FROM memories ORDER BY created_at DESC, id DESC",
            )
            .fetch_all(self.pool)
            .await?
        };

    let mut records = Vec::with_capacity(fetched.len());
    for row in fetched {
        records.push(Self::row_to_record(row)?);
    }
    Ok(records)
}
/// Returns records carrying `tag`, newest first, using LIMIT/OFFSET paging.
///
/// Records are matched through the `memory_tags` table. Rows are ordered by
/// `created_at DESC` with `id DESC` as tie-breaker. Without the tie-breaker,
/// records sharing a timestamp could come back in a different order on each
/// call, so consecutive pages could repeat or skip rows. The ordering now
/// matches the other listing queries of this store.
///
/// * `limit`  – maximum number of rows to return.
/// * `offset` – number of leading rows to skip.
pub async fn search_by_tag(
    &self,
    tag: &str,
    limit: u32,
    offset: u32,
) -> ClawResult<Vec<MemoryRecord>> {
    #[allow(clippy::type_complexity)]
    let rows: Vec<(String, String, String, String, Option<i64>, String, String)> =
        sqlx::query_as(
            "SELECT m.id, m.content, m.memory_type, m.tags, m.ttl_seconds, \
             m.created_at, m.updated_at \
             FROM memories m \
             JOIN memory_tags t ON m.id = t.memory_id \
             WHERE t.tag = ? \
             ORDER BY m.created_at DESC, m.id DESC LIMIT ? OFFSET ?",
        )
        .bind(tag)
        // u32 -> i64 is lossless; `from` makes that explicit.
        .bind(i64::from(limit))
        .bind(i64::from(offset))
        .fetch_all(self.pool)
        .await?;
    rows.into_iter().map(Self::row_to_record).collect()
}
/// Returns one page of records, newest first (`created_at DESC, id DESC`).
///
/// * `type_filter` – when set, restricts the page to that memory type.
/// * `opts` – page size (clamped by `validated_limit`) and optional cursor.
///   The cursor is the id of the last record of the previous page; only rows
///   that sort strictly after that record are returned.
///
/// One row more than the page size is requested. If it comes back, it is
/// dropped and `next_cursor` is set to the id of the last returned item;
/// otherwise `next_cursor` is `None`.
///
/// NOTE(review): a cursor that matches no row makes the cursor subquery yield
/// nothing, which presumably produces an empty page rather than an error —
/// confirm this is the desired behaviour for stale cursors.
///
/// # Errors
///
/// Database errors and row-decoding errors are propagated.
pub async fn list_paginated(
    &self,
    type_filter: Option<&MemoryType>,
    opts: &ListOptions,
) -> ClawResult<ListPage<MemoryRecord>> {
    let limit = opts.validated_limit();
    // `limit` is a u32, so widening to usize / i64 cannot lose information and
    // adding one cannot overflow the i64.
    let page_len = limit as usize;
    let fetch = i64::from(limit) + 1;

    #[allow(clippy::type_complexity)]
    let mut rows: Vec<(String, String, String, String, Option<i64>, String, String)> =
        match (&opts.cursor, type_filter) {
            (None, None) => sqlx::query_as(
                "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
                 FROM memories ORDER BY created_at DESC, id DESC LIMIT ?",
            )
            .bind(fetch)
            .fetch_all(self.pool)
            .await?,
            (None, Some(mt)) => sqlx::query_as(
                "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
                 FROM memories WHERE memory_type = ? \
                 ORDER BY created_at DESC, id DESC LIMIT ?",
            )
            .bind(mt.as_str())
            .bind(fetch)
            .fetch_all(self.pool)
            .await?,
            (Some(cursor), None) => sqlx::query_as(
                "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
                 FROM memories \
                 WHERE (created_at, id) < \
                 (SELECT created_at, id FROM memories WHERE id = ?) \
                 ORDER BY created_at DESC, id DESC LIMIT ?",
            )
            .bind(cursor)
            .bind(fetch)
            .fetch_all(self.pool)
            .await?,
            (Some(cursor), Some(mt)) => sqlx::query_as(
                "SELECT id, content, memory_type, tags, ttl_seconds, created_at, updated_at \
                 FROM memories \
                 WHERE memory_type = ? \
                 AND (created_at, id) < \
                 (SELECT created_at, id FROM memories WHERE id = ?) \
                 ORDER BY created_at DESC, id DESC LIMIT ?",
            )
            .bind(mt.as_str())
            .bind(cursor)
            .bind(fetch)
            .fetch_all(self.pool)
            .await?,
        };

    // The extra row only signals that another page exists; drop it and consume
    // the remaining rows by value instead of cloning each tuple.
    let has_more = rows.len() > page_len;
    rows.truncate(page_len);
    let items = rows
        .into_iter()
        .map(Self::row_to_record)
        .collect::<ClawResult<Vec<_>>>()?;

    let next_cursor = if has_more {
        items.last().map(|item| item.id.to_string())
    } else {
        None
    };
    Ok(ListPage { items, next_cursor })
}
/// Runs a full-text search and returns the matching records, newest first.
///
/// `query` is bound as the right-hand side of `memories_fts MATCH ?`; hits are
/// joined back to `memories` through the rowid and ordered by
/// `created_at DESC, id DESC`. Database and row-decoding errors are propagated.
pub async fn fts_search(&self, query: &str) -> ClawResult<Vec<MemoryRecord>> {
    #[allow(clippy::type_complexity)]
    let hits: Vec<(String, String, String, String, Option<i64>, String, String)> =
        sqlx::query_as(
            "SELECT m.id, m.content, m.memory_type, m.tags, m.ttl_seconds, m.created_at, m.updated_at \
             FROM memories_fts f \
             JOIN memories m ON m.rowid = f.rowid \
             WHERE memories_fts MATCH ? \
             ORDER BY m.created_at DESC, m.id DESC",
        )
        .bind(query)
        .fetch_all(self.pool)
        .await?;

    let mut records = Vec::with_capacity(hits.len());
    for hit in hits {
        records.push(Self::row_to_record(hit)?);
    }
    Ok(records)
}
/// Deletes every record whose TTL has elapsed and returns how many `memories`
/// rows were removed.
///
/// A record is expired when `now >= created_at + ttl_seconds`. The sweep works
/// in two phases:
/// 1. read all rows that have a TTL and decide which ones are expired;
/// 2. delete the expired records (full-text row, tag rows, main row) inside a
///    single transaction, so a failure never leaves a record half-deleted.
///
/// TTL handling is defensive so that no stored value can make the sweep panic:
/// * a negative stored value is what `insert`'s historical `u64 as i64` cast
///   produced for TTLs above `i64::MAX`; it is treated as "does not expire";
/// * values too large for `chrono::Duration::seconds`, or whose expiry
///   overflows the timestamp range, are likewise treated as "does not expire".
///
/// # Errors
///
/// A `created_at` value that is not valid RFC 3339 yields `ClawError::Store`
/// before anything is deleted; database errors are propagated.
pub async fn expire_ttl(&self) -> ClawResult<u64> {
    // `chrono::Duration::seconds` panics outside roughly +/- i64::MAX / 1000.
    const MAX_TTL_SECS: i64 = i64::MAX / 1_000;

    let rows: Vec<(String, Option<i64>, String)> = sqlx::query_as(
        "SELECT id, ttl_seconds, created_at FROM memories WHERE ttl_seconds IS NOT NULL",
    )
    .fetch_all(self.pool)
    .await?;

    let now = Utc::now();
    let mut expired_ids: Vec<String> = Vec::new();
    for (id_str, ttl_secs, created_at_str) in rows {
        let ttl = match ttl_secs {
            Some(ttl) => ttl,
            None => continue,
        };
        let created_at = DateTime::parse_from_rfc3339(&created_at_str)
            .map_err(|e| ClawError::Store(e.to_string()))?
            .with_timezone(&Utc);

        let expired = if ttl < 0 || ttl > MAX_TTL_SECS {
            false
        } else {
            match created_at.checked_add_signed(chrono::Duration::seconds(ttl)) {
                Some(expiry) => now >= expiry,
                // Expiry lies beyond the representable range: effectively never.
                None => false,
            }
        };
        if expired {
            expired_ids.push(id_str);
        }
    }

    if expired_ids.is_empty() {
        return Ok(0);
    }

    let mut tx = self.pool.begin().await?;
    let mut deleted = 0u64;
    for id_str in &expired_ids {
        // The full-text row is located through `memories`, so it goes first.
        sqlx::query("DELETE FROM memories_fts WHERE rowid = (SELECT rowid FROM memories WHERE id = ?)")
            .bind(id_str)
            .execute(&mut *tx)
            .await?;
        sqlx::query("DELETE FROM memory_tags WHERE memory_id = ?")
            .bind(id_str)
            .execute(&mut *tx)
            .await?;
        deleted += sqlx::query("DELETE FROM memories WHERE id = ?")
            .bind(id_str)
            .execute(&mut *tx)
            .await?
            .rows_affected();
    }
    tx.commit().await?;
    Ok(deleted)
}
fn row_to_record(
row: (String, String, String, String, Option<i64>, String, String),
) -> ClawResult<MemoryRecord> {
let (id_str, content, memory_type_str, tags_str, ttl_secs, created_at_str, updated_at_str) =
row;
Ok(MemoryRecord {
id: Uuid::parse_str(&id_str).map_err(|e| ClawError::Store(e.to_string()))?,
content,
memory_type: memory_type_str.parse()?,
tags: serde_json::from_str(&tags_str)?,
ttl_seconds: ttl_secs.map(|s| s as u64),
created_at: DateTime::parse_from_rfc3339(&created_at_str)
.map_err(|e| ClawError::Store(e.to_string()))?
.with_timezone(&Utc),
updated_at: DateTime::parse_from_rfc3339(&updated_at_str)
.map_err(|e| ClawError::Store(e.to_string()))?
.with_timezone(&Utc),
})
}
}