1use crate::cache::TableCache;
7use alloc::collections::BTreeMap;
8use alloc::string::String;
9use alloc::vec::Vec;
10use cynos_core::{Result, Row, RowId};
11
12#[derive(Clone, Debug)]
14pub enum JournalEntry {
15 Insert {
17 table: String,
18 row_id: RowId,
19 row: Row,
20 },
21 Update {
23 table: String,
24 row_id: RowId,
25 old: Row,
26 new: Row,
27 },
28 Delete {
30 table: String,
31 row_id: RowId,
32 row: Row,
33 },
34}
35
36impl JournalEntry {
37 pub fn table(&self) -> &str {
39 match self {
40 JournalEntry::Insert { table, .. } => table,
41 JournalEntry::Update { table, .. } => table,
42 JournalEntry::Delete { table, .. } => table,
43 }
44 }
45
46 pub fn row_id(&self) -> RowId {
48 match self {
49 JournalEntry::Insert { row_id, .. } => *row_id,
50 JournalEntry::Update { row_id, .. } => *row_id,
51 JournalEntry::Delete { row_id, .. } => *row_id,
52 }
53 }
54}
55
56#[derive(Clone, Debug, Default)]
58pub struct TableDiff {
59 table_name: String,
61 added: BTreeMap<RowId, Row>,
63 modified: BTreeMap<RowId, (Row, Row)>,
65 deleted: BTreeMap<RowId, Row>,
67}
68
69impl TableDiff {
70 pub fn new(table_name: impl Into<String>) -> Self {
72 Self {
73 table_name: table_name.into(),
74 added: BTreeMap::new(),
75 modified: BTreeMap::new(),
76 deleted: BTreeMap::new(),
77 }
78 }
79
80 pub fn table_name(&self) -> &str {
82 &self.table_name
83 }
84
85 pub fn add(&mut self, row: Row) {
87 let row_id = row.id();
88 if let Some(old_row) = self.deleted.remove(&row_id) {
90 self.modified.insert(row_id, (old_row, row));
91 } else {
92 self.added.insert(row_id, row);
93 }
94 }
95
96 pub fn modify(&mut self, old: Row, new: Row) {
98 let row_id = old.id();
99 if self.added.contains_key(&row_id) {
101 self.added.insert(row_id, new);
102 } else if let Some((original_old, _)) = self.modified.get(&row_id) {
103 let original = original_old.clone();
105 self.modified.insert(row_id, (original, new));
106 } else {
107 self.modified.insert(row_id, (old, new));
108 }
109 }
110
111 pub fn delete(&mut self, row: Row) {
113 let row_id = row.id();
114 if self.added.remove(&row_id).is_some() {
116 return;
117 }
118 if let Some((old_row, _)) = self.modified.remove(&row_id) {
120 self.deleted.insert(row_id, old_row);
121 } else {
122 self.deleted.insert(row_id, row);
123 }
124 }
125
126 pub fn get_added(&self) -> &BTreeMap<RowId, Row> {
128 &self.added
129 }
130
131 pub fn get_modified(&self) -> &BTreeMap<RowId, (Row, Row)> {
133 &self.modified
134 }
135
136 pub fn get_deleted(&self) -> &BTreeMap<RowId, Row> {
138 &self.deleted
139 }
140
141 pub fn is_empty(&self) -> bool {
143 self.added.is_empty() && self.modified.is_empty() && self.deleted.is_empty()
144 }
145
146 pub fn get_reverse(&self) -> Self {
148 let mut reverse = Self::new(&self.table_name);
149
150 for (row_id, row) in &self.added {
152 reverse.deleted.insert(*row_id, row.clone());
153 }
154
155 for (row_id, (old, new)) in &self.modified {
157 reverse.modified.insert(*row_id, (new.clone(), old.clone()));
158 }
159
160 for (row_id, row) in &self.deleted {
162 reverse.added.insert(*row_id, row.clone());
163 }
164
165 reverse
166 }
167
168 pub fn get_as_modifications(&self) -> Vec<(Option<Row>, Option<Row>)> {
170 let mut mods = Vec::new();
171
172 for row in self.added.values() {
173 mods.push((None, Some(row.clone())));
174 }
175
176 for (old, new) in self.modified.values() {
177 mods.push((Some(old.clone()), Some(new.clone())));
178 }
179
180 for row in self.deleted.values() {
181 mods.push((Some(row.clone()), None));
182 }
183
184 mods
185 }
186}
187
188pub struct Journal {
190 table_diffs: BTreeMap<String, TableDiff>,
192 entries: Vec<JournalEntry>,
194}
195
196impl Journal {
197 pub fn new() -> Self {
199 Self {
200 table_diffs: BTreeMap::new(),
201 entries: Vec::new(),
202 }
203 }
204
205 pub fn record_insert(&mut self, table: &str, row: Row) {
207 let row_id = row.id();
208
209 self.get_or_create_diff(table).add(row.clone());
210
211 self.entries.push(JournalEntry::Insert {
212 table: table.into(),
213 row_id,
214 row,
215 });
216 }
217
218 pub fn record_update(&mut self, table: &str, old: Row, new: Row) {
220 let row_id = old.id();
221
222 self.get_or_create_diff(table).modify(old.clone(), new.clone());
223
224 self.entries.push(JournalEntry::Update {
225 table: table.into(),
226 row_id,
227 old,
228 new,
229 });
230 }
231
232 pub fn record_delete(&mut self, table: &str, row: Row) {
234 let row_id = row.id();
235
236 self.get_or_create_diff(table).delete(row.clone());
237
238 self.entries.push(JournalEntry::Delete {
239 table: table.into(),
240 row_id,
241 row,
242 });
243 }
244
245 fn get_or_create_diff(&mut self, table: &str) -> &mut TableDiff {
247 if !self.table_diffs.contains_key(table) {
248 self.table_diffs.insert(table.into(), TableDiff::new(table));
249 }
250 self.table_diffs.get_mut(table).unwrap()
251 }
252
253 pub fn get_entries(&self) -> &[JournalEntry] {
255 &self.entries
256 }
257
258 pub fn get_table_diff(&self, table: &str) -> Option<&TableDiff> {
260 self.table_diffs.get(table)
261 }
262
263 pub fn get_all_diffs(&self) -> &BTreeMap<String, TableDiff> {
265 &self.table_diffs
266 }
267
268 pub fn is_empty(&self) -> bool {
270 self.entries.is_empty()
271 }
272
273 pub fn commit(&mut self) -> Vec<JournalEntry> {
277 let entries = core::mem::take(&mut self.entries);
278 self.table_diffs.clear();
279 entries
280 }
281
282 pub fn rollback(&mut self, cache: &mut TableCache) -> Result<()> {
284 for entry in self.entries.iter().rev() {
286 match entry {
287 JournalEntry::Insert { table, row_id, .. } => {
288 if let Some(store) = cache.get_table_mut(table) {
289 let _ = store.delete(*row_id);
290 }
291 }
292 JournalEntry::Update { table, row_id, old, new } => {
293 if let Some(store) = cache.get_table_mut(table) {
294 let rollback_row = Row::new_with_version(
296 old.id(),
297 new.version().wrapping_add(1),
298 old.values().to_vec(),
299 );
300 let _ = store.update(*row_id, rollback_row);
301 }
302 }
303 JournalEntry::Delete { table, row, .. } => {
304 if let Some(store) = cache.get_table_mut(table) {
305 let _ = store.insert(row.clone());
306 }
307 }
308 }
309 }
310
311 self.entries.clear();
312 self.table_diffs.clear();
313 Ok(())
314 }
315
316 pub fn clear(&mut self) {
318 self.entries.clear();
319 self.table_diffs.clear();
320 }
321}
322
323impl Default for Journal {
324 fn default() -> Self {
325 Self::new()
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use cynos_core::schema::TableBuilder;
    use cynos_core::{DataType, Value};

    /// Schema for a table named `test` with an `id` (Int64) column, a `name`
    /// (String) column, and a primary key on `id`.
    fn test_schema() -> cynos_core::schema::Table {
        TableBuilder::new("test")
            .unwrap()
            .add_column("id", DataType::Int64)
            .unwrap()
            .add_column("name", DataType::String)
            .unwrap()
            .add_primary_key(&["id"], false)
            .unwrap()
            .build()
            .unwrap()
    }

    /// Fresh cache that already contains the empty `test` table.
    fn cache_with_test_table() -> TableCache {
        let mut cache = TableCache::new();
        cache.create_table(test_schema()).unwrap();
        cache
    }

    #[test]
    fn test_journal_insert() {
        let mut cache = cache_with_test_table();
        let mut journal = Journal::new();
        let inserted = Row::new(1, vec![Value::Int64(1), Value::String("test".into())]);

        let store = cache.get_table_mut("test").unwrap();
        store.insert(inserted.clone()).unwrap();
        journal.record_insert("test", inserted);

        assert_eq!(journal.get_entries().len(), 1);
        assert_eq!(cache.get_table("test").unwrap().len(), 1);
    }

    #[test]
    fn test_journal_rollback() {
        let mut cache = cache_with_test_table();

        // Row present before journaling starts; rollback must leave it alone.
        let baseline = Row::new(1, vec![Value::Int64(1), Value::String("initial".into())]);
        cache.get_table_mut("test").unwrap().insert(baseline).unwrap();
        assert_eq!(cache.get_table("test").unwrap().len(), 1);

        // Row inserted under the journal; rollback must remove it.
        let mut journal = Journal::new();
        let journaled = Row::new(2, vec![Value::Int64(2), Value::String("second".into())]);
        cache
            .get_table_mut("test")
            .unwrap()
            .insert(journaled.clone())
            .unwrap();
        journal.record_insert("test", journaled);

        assert_eq!(cache.get_table("test").unwrap().len(), 2);

        journal.rollback(&mut cache).unwrap();

        assert_eq!(cache.get_table("test").unwrap().len(), 1);
    }

    #[test]
    fn test_table_diff_add_delete() {
        let mut diff = TableDiff::new("test");
        let row = Row::new(1, vec![Value::Int64(1)]);

        diff.add(row.clone());
        assert_eq!(diff.get_added().len(), 1);

        // Deleting a row that was only added cancels the addition.
        diff.delete(row);
        assert!(diff.is_empty());
    }

    #[test]
    fn test_table_diff_modify() {
        let mut diff = TableDiff::new("test");
        let before = Row::new(1, vec![Value::Int64(1)]);
        let after = Row::new(1, vec![Value::Int64(2)]);

        diff.modify(before, after);

        assert_eq!(diff.get_modified().len(), 1);
    }

    #[test]
    fn test_table_diff_reverse() {
        let mut diff = TableDiff::new("test");
        diff.add(Row::new(1, vec![Value::Int64(1)]));

        // Reversing an addition yields a deletion and no additions.
        let reverse = diff.get_reverse();
        assert_eq!(reverse.get_deleted().len(), 1);
        assert!(reverse.get_added().is_empty());
    }

    #[test]
    fn test_table_diff_get_as_modifications() {
        let mut diff = TableDiff::new("test");

        let added = Row::new(1, vec![Value::Int64(1)]);
        let changed_from = Row::new(2, vec![Value::Int64(2)]);
        let changed_to = Row::new(2, vec![Value::Int64(20)]);
        let removed = Row::new(3, vec![Value::Int64(3)]);

        diff.add(added);
        diff.modify(changed_from, changed_to);
        diff.delete(removed);

        // One pair per touched row: one add, one modify, one delete.
        let mods = diff.get_as_modifications();
        assert_eq!(mods.len(), 3);
    }
}