1use std::sync::Arc;
4
5use chrono::Utc;
6use serde::{Deserialize, Serialize};
7use sqlx::{Sqlite, Transaction};
8use tokio::sync::Mutex;
9use uuid::Uuid;
10
11use crate::cache::ClawCache;
12use crate::engine::ClawEngine;
13use crate::error::{ClawError, ClawResult};
14use crate::store::memory::MemoryRecord;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct CacheRecord {
19 pub agent_id: String,
21 pub key: String,
23 pub value: serde_json::Value,
25}
26
27#[derive(Debug, Clone)]
29pub enum CacheOp {
30 Insert(Uuid, CacheRecord),
32 Delete(Uuid),
34 Clear,
36}
37
38#[derive(Debug, Clone)]
40pub enum StagedMemoryOp {
41 Insert(MemoryRecord),
43 Update {
45 id: Uuid,
47 content: String,
49 },
50 Delete {
52 id: Uuid,
54 },
55}
56
57pub struct ClawTransaction<'c> {
59 inner: Transaction<'c, Sqlite>,
60 staged: Vec<StagedMemoryOp>,
62 cache_ops: Vec<CacheOp>,
63 cache: Arc<Mutex<ClawCache<Uuid, MemoryRecord>>>,
64}
65
66impl<'c> ClawTransaction<'c> {
67 pub(crate) fn new(
69 inner: Transaction<'c, Sqlite>,
70 cache: Arc<Mutex<ClawCache<Uuid, MemoryRecord>>>,
71 ) -> Self {
72 ClawTransaction {
73 inner,
74 staged: Vec::new(),
75 cache_ops: Vec::new(),
76 cache,
77 }
78 }
79
80 pub async fn begin(engine: &'c ClawEngine) -> ClawResult<ClawTransaction<'c>> {
82 let tx = engine
83 .pool
84 .begin()
85 .await
86 .map_err(|e| ClawError::Transaction(e.to_string()))?;
87 Ok(ClawTransaction::new(tx, Arc::clone(&engine.cache)))
88 }
89
90 pub fn stage(&mut self, op: CacheOp) {
92 self.cache_ops.push(op);
93 }
94
95 pub fn pending_cache_ops(&self) -> &[CacheOp] {
97 &self.cache_ops
98 }
99
100 pub async fn insert_memory(&mut self, record: &MemoryRecord) -> ClawResult<Uuid> {
102 self.staged.push(StagedMemoryOp::Insert(record.clone()));
103 Ok(record.id)
104 }
105
106 pub fn update_memory(&mut self, id: Uuid, content: impl Into<String>) {
108 self.staged.push(StagedMemoryOp::Update {
109 id,
110 content: content.into(),
111 });
112 }
113
114 pub fn delete_memory(&mut self, id: Uuid) {
116 self.staged.push(StagedMemoryOp::Delete { id });
117 }
118
119 pub async fn commit(mut self) -> ClawResult<()> {
121 let mut committed_records = Vec::new();
122
123 for op in std::mem::take(&mut self.staged) {
124 match op {
125 StagedMemoryOp::Insert(record) => {
126 let tags =
127 serde_json::to_string(&record.tags).map_err(ClawError::Serialization)?;
128 sqlx::query(
129 "INSERT INTO memories \
130 (id, content, memory_type, tags, ttl_seconds, created_at, updated_at) \
131 VALUES (?, ?, ?, ?, ?, ?, ?)",
132 )
133 .bind(record.id.to_string())
134 .bind(&record.content)
135 .bind(record.memory_type.as_str())
136 .bind(&tags)
137 .bind(record.ttl_seconds.map(|s| s as i64))
138 .bind(record.created_at.to_rfc3339())
139 .bind(record.updated_at.to_rfc3339())
140 .execute(&mut *self.inner)
141 .await?;
142
143 sqlx::query(
144 "INSERT INTO memories_fts(rowid, content) \
145 VALUES (last_insert_rowid(), ?)",
146 )
147 .bind(&record.content)
148 .execute(&mut *self.inner)
149 .await?;
150
151 for tag in &record.tags {
152 sqlx::query(
153 "INSERT OR REPLACE INTO memory_tags(memory_id, tag) VALUES (?, ?)",
154 )
155 .bind(record.id.to_string())
156 .bind(tag)
157 .execute(&mut *self.inner)
158 .await?;
159 }
160
161 committed_records.push(record);
162 }
163 StagedMemoryOp::Update { id, content } => {
164 let updated_at = Utc::now().to_rfc3339();
165 let affected =
166 sqlx::query("UPDATE memories SET content = ?, updated_at = ? WHERE id = ?")
167 .bind(&content)
168 .bind(updated_at)
169 .bind(id.to_string())
170 .execute(&mut *self.inner)
171 .await?
172 .rows_affected();
173
174 if affected == 0 {
175 return Err(ClawError::NotFound {
176 entity: "MemoryRecord".to_string(),
177 id: id.to_string(),
178 });
179 }
180
181 sqlx::query(
182 "DELETE FROM memories_fts WHERE rowid = \
183 (SELECT rowid FROM memories WHERE id = ?)",
184 )
185 .bind(id.to_string())
186 .execute(&mut *self.inner)
187 .await?;
188
189 sqlx::query(
190 "INSERT INTO memories_fts(rowid, content) \
191 VALUES ((SELECT rowid FROM memories WHERE id = ?), ?)",
192 )
193 .bind(id.to_string())
194 .bind(content)
195 .execute(&mut *self.inner)
196 .await?;
197 }
198 StagedMemoryOp::Delete { id } => {
199 sqlx::query(
200 "DELETE FROM memories_fts WHERE rowid = \
201 (SELECT rowid FROM memories WHERE id = ?)",
202 )
203 .bind(id.to_string())
204 .execute(&mut *self.inner)
205 .await?;
206
207 let affected = sqlx::query("DELETE FROM memories WHERE id = ?")
208 .bind(id.to_string())
209 .execute(&mut *self.inner)
210 .await?
211 .rows_affected();
212
213 if affected == 0 {
214 return Err(ClawError::NotFound {
215 entity: "MemoryRecord".to_string(),
216 id: id.to_string(),
217 });
218 }
219 }
220 }
221 }
222
223 self.inner
224 .commit()
225 .await
226 .map_err(|e| ClawError::Transaction(e.to_string()))?;
227
228 if !committed_records.is_empty() {
229 let mut cache = self.cache.lock().await;
230 for record in committed_records {
231 cache.insert(record.id, record);
232 }
233 }
234
235 Ok(())
236 }
237
238 pub async fn rollback(mut self) -> ClawResult<()> {
240 self.staged.clear();
241 self.inner
242 .rollback()
243 .await
244 .map_err(|e| ClawError::Transaction(e.to_string()))
245 }
246
247 pub fn inner_mut(&mut self) -> &mut Transaction<'c, Sqlite> {
249 &mut self.inner
250 }
251}
252
253impl<'c> std::fmt::Debug for ClawTransaction<'c> {
254 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255 f.debug_struct("ClawTransaction").finish_non_exhaustive()
256 }
257}