use serde::{Deserialize, Serialize};
use sqlx::{Sqlite, Transaction};
use uuid::Uuid;
use crate::engine::ClawEngine;
use crate::error::{ClawError, ClawResult};
use crate::store::memory::MemoryRecord;
/// A cache entry pairing an agent id and a key with a JSON value.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct CacheRecord {
    /// Agent identifier associated with this entry.
    pub agent_id: String,
    /// Cache key for this entry.
    pub key: String,
    /// Stored JSON payload.
    pub value: serde_json::Value,
}
/// A cache mutation queued on a transaction via `ClawTransaction::stage`.
///
/// Nothing in this file applies these operations; they are only collected.
#[derive(Clone)]
#[derive(Debug)]
pub enum CacheOp {
    /// Insert request carrying an id and the record to store.
    Insert(Uuid, CacheRecord),
    /// Delete request carrying the id of the entry to remove.
    Delete(Uuid),
    /// Request to clear the cache.
    Clear,
}
/// A database write deferred until `ClawTransaction::commit`.
#[derive(Clone)]
#[derive(Debug)]
pub enum StagedOp {
    /// Written at commit time into `memories`, `memories_fts` and `memory_tags`.
    InsertMemory(MemoryRecord),
}
/// A SQLite transaction that defers memory writes until `commit` and
/// collects cache operations (exposed via `pending_cache_ops`, never applied here).
pub struct ClawTransaction<'c> {
// Underlying sqlx transaction that all staged writes are executed on.
inner: Transaction<'c, Sqlite>,
// Writes queued by `insert_memory`; flushed in insertion order by `commit`.
staged: Vec<StagedOp>,
// Cache operations queued by `stage`; only stored and exposed, not executed.
cache_ops: Vec<CacheOp>,
}
impl<'c> ClawTransaction<'c> {
    /// Wraps an already-open SQLite transaction with empty staging queues.
    pub(crate) fn new(inner: Transaction<'c, Sqlite>) -> Self {
        ClawTransaction {
            inner,
            staged: Vec::new(),
            cache_ops: Vec::new(),
        }
    }

    /// Begins a new transaction on the engine's connection pool.
    ///
    /// # Errors
    ///
    /// Returns [`ClawError::Transaction`] if the pool cannot start a transaction.
    pub async fn begin(engine: &'c ClawEngine) -> ClawResult<ClawTransaction<'c>> {
        let tx = engine
            .pool
            .begin()
            .await
            .map_err(|e| ClawError::Transaction(e.to_string()))?;
        Ok(ClawTransaction::new(tx))
    }

    /// Queues a cache operation alongside this transaction.
    ///
    /// The operation is only recorded: neither [`commit`](Self::commit) nor
    /// [`rollback`](Self::rollback) applies it, and both consume `self`, so
    /// callers must read [`pending_cache_ops`](Self::pending_cache_ops) first.
    pub fn stage(&mut self, op: CacheOp) {
        self.cache_ops.push(op);
    }

    /// Returns the cache operations queued so far, in staging order.
    pub fn pending_cache_ops(&self) -> &[CacheOp] {
        &self.cache_ops
    }

    /// Writes every staged operation in the order it was staged, then commits.
    ///
    /// Queued cache operations are not applied (see [`stage`](Self::stage)).
    ///
    /// # Errors
    ///
    /// Returns the first serialization or database error hit while writing a
    /// staged operation, or [`ClawError::Transaction`] if the final commit
    /// fails. On an early error the inner transaction is dropped without being
    /// committed; per sqlx's `Transaction` docs, a dropped transaction is
    /// rolled back.
    pub async fn commit(mut self) -> ClawResult<()> {
        // Take the queue so `self` can be borrowed mutably by the writer below.
        for op in std::mem::take(&mut self.staged) {
            match op {
                StagedOp::InsertMemory(record) => self.write_memory(&record).await?,
            }
        }
        self.inner
            .commit()
            .await
            .map_err(|e| ClawError::Transaction(e.to_string()))
    }

    /// Inserts one memory row plus its full-text-search row and tag rows.
    async fn write_memory(&mut self, record: &MemoryRecord) -> ClawResult<()> {
        // The id is bound into every statement below; stringify it once.
        let id = record.id.to_string();
        let tags = serde_json::to_string(&record.tags).map_err(ClawError::Serialization)?;
        // `ttl_seconds` is presumably an unsigned integer: clamp values above
        // `i64::MAX` instead of letting an `as` cast wrap them negative.
        let ttl_seconds = record
            .ttl_seconds
            .map(|s| i64::try_from(s).unwrap_or(i64::MAX));

        sqlx::query(
            "INSERT INTO memories \
             (id, content, memory_type, tags, ttl_seconds, \
             created_at, updated_at) \
             VALUES (?, ?, ?, ?, ?, ?, ?)",
        )
        .bind(&id)
        .bind(&record.content)
        .bind(record.memory_type.as_str())
        .bind(&tags)
        .bind(ttl_seconds)
        .bind(record.created_at.to_rfc3339())
        .bind(record.updated_at.to_rfc3339())
        .execute(&mut *self.inner)
        .await?;

        sqlx::query("INSERT INTO memories_fts(id, content) VALUES (?, ?)")
            .bind(&id)
            .bind(&record.content)
            .execute(&mut *self.inner)
            .await?;

        // OR IGNORE tolerates duplicate tags within the same record.
        for tag in &record.tags {
            sqlx::query(
                "INSERT OR IGNORE INTO memory_tags(memory_id, tag) \
                 VALUES (?, ?)",
            )
            .bind(&id)
            .bind(tag)
            .execute(&mut *self.inner)
            .await?;
        }
        Ok(())
    }

    /// Discards all staged writes and rolls back the underlying transaction.
    ///
    /// # Errors
    ///
    /// Returns [`ClawError::Transaction`] if the rollback fails.
    pub async fn rollback(mut self) -> ClawResult<()> {
        self.staged.clear();
        self.inner
            .rollback()
            .await
            .map_err(|e| ClawError::Transaction(e.to_string()))
    }

    /// Direct access to the underlying sqlx transaction.
    ///
    /// Statements executed through this handle bypass the staging queue; staged
    /// writes are only executed later, inside [`commit`](Self::commit).
    pub fn inner_mut(&mut self) -> &mut Transaction<'c, Sqlite> {
        &mut self.inner
    }

    /// Queues a copy of `record` for insertion at commit time and returns its id.
    ///
    /// Nothing is written to the database until [`commit`](Self::commit). The
    /// method is `async` only for API compatibility; it never awaits.
    ///
    /// # Errors
    ///
    /// Currently infallible; always returns `Ok(record.id)`.
    pub async fn insert_memory(&mut self, record: &MemoryRecord) -> ClawResult<Uuid> {
        self.staged.push(StagedOp::InsertMemory(record.clone()));
        Ok(record.id)
    }
}
/// Summarises the transaction by queue sizes.
///
/// The underlying sqlx transaction and the queued payloads (which may hold
/// large memory contents) are deliberately omitted.
impl<'c> std::fmt::Debug for ClawTransaction<'c> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClawTransaction")
            .field("staged_ops", &self.staged.len())
            .field("cache_ops", &self.cache_ops.len())
            .finish_non_exhaustive()
    }
}