Skip to main content

claw_core/
transaction.rs

1//! Transaction wrapper for claw-core.
2
3use std::sync::Arc;
4
5use chrono::Utc;
6use serde::{Deserialize, Serialize};
7use sqlx::{Sqlite, Transaction};
8use tokio::sync::Mutex;
9use uuid::Uuid;
10
11use crate::cache::ClawCache;
12use crate::engine::ClawEngine;
13use crate::error::{ClawError, ClawResult};
14use crate::store::memory::MemoryRecord;
15
16/// A key-value cache record that can be staged for deferred cache operations.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct CacheRecord {
19    /// The agent that owns this record.
20    pub agent_id: String,
21    /// Logical key within the agent's scratchpad.
22    pub key: String,
23    /// Serialized JSON value.
24    pub value: serde_json::Value,
25}
26
27/// A deferred cache mutation that should be applied after a successful commit.
28#[derive(Debug, Clone)]
29pub enum CacheOp {
30    /// Insert or update the record identified by the given [`Uuid`].
31    Insert(Uuid, CacheRecord),
32    /// Remove the record identified by the given [`Uuid`] from the cache.
33    Delete(Uuid),
34    /// Invalidate the entire cache.
35    Clear,
36}
37
38/// A database write operation staged in [`ClawTransaction`].
39#[derive(Debug, Clone)]
40pub enum StagedMemoryOp {
41    /// Insert a new [`MemoryRecord`].
42    Insert(MemoryRecord),
43    /// Update memory content by id.
44    Update {
45        /// Target memory id.
46        id: Uuid,
47        /// New content.
48        content: String,
49    },
50    /// Delete memory by id.
51    Delete {
52        /// Target memory id.
53        id: Uuid,
54    },
55}
56
57/// A wrapper around a SQLite transaction with staged memory operations.
58pub struct ClawTransaction<'c> {
59    inner: Transaction<'c, Sqlite>,
60    /// Pending memory operations applied atomically during commit.
61    staged: Vec<StagedMemoryOp>,
62    cache_ops: Vec<CacheOp>,
63    cache: Arc<Mutex<ClawCache<Uuid, MemoryRecord>>>,
64}
65
66impl<'c> ClawTransaction<'c> {
67    /// Wrap a raw SQLx transaction.
68    pub(crate) fn new(
69        inner: Transaction<'c, Sqlite>,
70        cache: Arc<Mutex<ClawCache<Uuid, MemoryRecord>>>,
71    ) -> Self {
72        ClawTransaction {
73            inner,
74            staged: Vec::new(),
75            cache_ops: Vec::new(),
76            cache,
77        }
78    }
79
80    /// Begin a new transaction against the engine's connection pool.
81    pub async fn begin(engine: &'c ClawEngine) -> ClawResult<ClawTransaction<'c>> {
82        let tx = engine
83            .pool
84            .begin()
85            .await
86            .map_err(|e| ClawError::Transaction(e.to_string()))?;
87        Ok(ClawTransaction::new(tx, Arc::clone(&engine.cache)))
88    }
89
90    /// Stage a [`CacheOp`] to be applied after successful commit.
91    pub fn stage(&mut self, op: CacheOp) {
92        self.cache_ops.push(op);
93    }
94
95    /// Return a slice of staged cache operations.
96    pub fn pending_cache_ops(&self) -> &[CacheOp] {
97        &self.cache_ops
98    }
99
100    /// Stage insertion of a [`MemoryRecord`].
101    pub async fn insert_memory(&mut self, record: &MemoryRecord) -> ClawResult<Uuid> {
102        self.staged.push(StagedMemoryOp::Insert(record.clone()));
103        Ok(record.id)
104    }
105
106    /// Stage update of memory content.
107    pub fn update_memory(&mut self, id: Uuid, content: impl Into<String>) {
108        self.staged.push(StagedMemoryOp::Update {
109            id,
110            content: content.into(),
111        });
112    }
113
114    /// Stage deletion of a memory row.
115    pub fn delete_memory(&mut self, id: Uuid) {
116        self.staged.push(StagedMemoryOp::Delete { id });
117    }
118
119    /// Commit all staged operations atomically.
120    pub async fn commit(mut self) -> ClawResult<()> {
121        let mut committed_records = Vec::new();
122
123        for op in std::mem::take(&mut self.staged) {
124            match op {
125                StagedMemoryOp::Insert(record) => {
126                    let tags =
127                        serde_json::to_string(&record.tags).map_err(ClawError::Serialization)?;
128                    sqlx::query(
129                        "INSERT INTO memories \
130                         (id, content, memory_type, tags, ttl_seconds, created_at, updated_at) \
131                         VALUES (?, ?, ?, ?, ?, ?, ?)",
132                    )
133                    .bind(record.id.to_string())
134                    .bind(&record.content)
135                    .bind(record.memory_type.as_str())
136                    .bind(&tags)
137                    .bind(record.ttl_seconds.map(|s| s as i64))
138                    .bind(record.created_at.to_rfc3339())
139                    .bind(record.updated_at.to_rfc3339())
140                    .execute(&mut *self.inner)
141                    .await?;
142
143                    sqlx::query(
144                        "INSERT INTO memories_fts(rowid, content) \
145                        VALUES (last_insert_rowid(), ?)",
146                    )
147                    .bind(&record.content)
148                    .execute(&mut *self.inner)
149                    .await?;
150
151                    for tag in &record.tags {
152                        sqlx::query(
153                            "INSERT OR REPLACE INTO memory_tags(memory_id, tag) VALUES (?, ?)",
154                        )
155                        .bind(record.id.to_string())
156                        .bind(tag)
157                        .execute(&mut *self.inner)
158                        .await?;
159                    }
160
161                    committed_records.push(record);
162                }
163                StagedMemoryOp::Update { id, content } => {
164                    let updated_at = Utc::now().to_rfc3339();
165                    let affected =
166                        sqlx::query("UPDATE memories SET content = ?, updated_at = ? WHERE id = ?")
167                            .bind(&content)
168                            .bind(updated_at)
169                            .bind(id.to_string())
170                            .execute(&mut *self.inner)
171                            .await?
172                            .rows_affected();
173
174                    if affected == 0 {
175                        return Err(ClawError::NotFound {
176                            entity: "MemoryRecord".to_string(),
177                            id: id.to_string(),
178                        });
179                    }
180
181                    sqlx::query(
182                        "DELETE FROM memories_fts WHERE rowid = \
183                         (SELECT rowid FROM memories WHERE id = ?)",
184                    )
185                    .bind(id.to_string())
186                    .execute(&mut *self.inner)
187                    .await?;
188
189                    sqlx::query(
190                        "INSERT INTO memories_fts(rowid, content) \
191                         VALUES ((SELECT rowid FROM memories WHERE id = ?), ?)",
192                    )
193                    .bind(id.to_string())
194                    .bind(content)
195                    .execute(&mut *self.inner)
196                    .await?;
197                }
198                StagedMemoryOp::Delete { id } => {
199                    sqlx::query(
200                        "DELETE FROM memories_fts WHERE rowid = \
201                         (SELECT rowid FROM memories WHERE id = ?)",
202                    )
203                    .bind(id.to_string())
204                    .execute(&mut *self.inner)
205                    .await?;
206
207                    let affected = sqlx::query("DELETE FROM memories WHERE id = ?")
208                        .bind(id.to_string())
209                        .execute(&mut *self.inner)
210                        .await?
211                        .rows_affected();
212
213                    if affected == 0 {
214                        return Err(ClawError::NotFound {
215                            entity: "MemoryRecord".to_string(),
216                            id: id.to_string(),
217                        });
218                    }
219                }
220            }
221        }
222
223        self.inner
224            .commit()
225            .await
226            .map_err(|e| ClawError::Transaction(e.to_string()))?;
227
228        if !committed_records.is_empty() {
229            let mut cache = self.cache.lock().await;
230            for record in committed_records {
231                cache.insert(record.id, record);
232            }
233        }
234
235        Ok(())
236    }
237
238    /// Explicitly roll back the transaction, discarding all staged changes.
239    pub async fn rollback(mut self) -> ClawResult<()> {
240        self.staged.clear();
241        self.inner
242            .rollback()
243            .await
244            .map_err(|e| ClawError::Transaction(e.to_string()))
245    }
246
247    /// Return a mutable reference to the inner SQLx transaction.
248    pub fn inner_mut(&mut self) -> &mut Transaction<'c, Sqlite> {
249        &mut self.inner
250    }
251}
252
253impl<'c> std::fmt::Debug for ClawTransaction<'c> {
254    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255        f.debug_struct("ClawTransaction").finish_non_exhaustive()
256    }
257}