1use crate::cache::TableCache;
6use crate::journal::{Journal, JournalEntry};
7use alloc::vec::Vec;
8use cynos_core::{Error, Result, Row, RowId};
9use core::sync::atomic::{AtomicU64, Ordering};
10
/// Numeric identifier carried by a transaction.
pub type TransactionId = u64;

/// Counter from which transaction identifiers are drawn.
///
/// It starts at 1 and `Transaction::begin` advances it by one for every
/// transaction it creates.
static NEXT_TX_ID: AtomicU64 = AtomicU64::new(1);
16
/// Lifecycle state of a transaction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionState {
    /// The transaction accepts operations and can still be committed or rolled back.
    Active,
    /// The transaction has been committed.
    Committed,
    /// The transaction has been rolled back.
    RolledBack,
}
27
28pub struct Transaction {
30 id: TransactionId,
32 journal: Journal,
34 state: TransactionState,
36}
37
38impl Transaction {
39 pub fn begin() -> Self {
41 Self {
42 id: NEXT_TX_ID.fetch_add(1, Ordering::SeqCst),
43 journal: Journal::new(),
44 state: TransactionState::Active,
45 }
46 }
47
48 pub fn id(&self) -> TransactionId {
50 self.id
51 }
52
53 pub fn state(&self) -> TransactionState {
55 self.state
56 }
57
58 pub fn is_active(&self) -> bool {
60 self.state == TransactionState::Active
61 }
62
63 fn check_active(&self) -> Result<()> {
65 if self.state != TransactionState::Active {
66 return Err(Error::invalid_operation("Transaction is not active"));
67 }
68 Ok(())
69 }
70
71 pub fn insert(&mut self, cache: &mut TableCache, table: &str, row: Row) -> Result<RowId> {
73 self.check_active()?;
74
75 let store = cache.get_table_mut(table).ok_or_else(|| Error::table_not_found(table))?;
76 let row_id = store.insert(row.clone())?;
77
78 self.journal.record_insert(table, row);
79 Ok(row_id)
80 }
81
82 pub fn update(&mut self, cache: &mut TableCache, table: &str, row_id: RowId, new_row: Row) -> Result<()> {
84 self.check_active()?;
85
86 let store = cache.get_table_mut(table).ok_or_else(|| Error::table_not_found(table))?;
87 let old_row = store.get(row_id).ok_or_else(|| {
88 Error::not_found(table, cynos_core::Value::Int64(row_id as i64))
89 })?;
90
91 let old_row_owned = (*old_row).clone();
92 store.update(row_id, new_row.clone())?;
93 self.journal.record_update(table, old_row_owned, new_row);
94 Ok(())
95 }
96
97 pub fn delete(&mut self, cache: &mut TableCache, table: &str, row_id: RowId) -> Result<Row> {
99 self.check_active()?;
100
101 let store = cache.get_table_mut(table).ok_or_else(|| Error::table_not_found(table))?;
102 let row = store.delete(row_id)?;
103
104 let row_owned = (*row).clone();
105 self.journal.record_delete(table, row_owned.clone());
106 Ok(row_owned)
107 }
108
109 pub fn commit(mut self) -> Result<Vec<JournalEntry>> {
111 self.check_active()?;
112 self.state = TransactionState::Committed;
113 Ok(self.journal.commit())
114 }
115
116 pub fn rollback(mut self, cache: &mut TableCache) -> Result<()> {
118 self.check_active()?;
119 self.state = TransactionState::RolledBack;
120 self.journal.rollback(cache)
121 }
122
123 pub fn get_changes(&self) -> &[JournalEntry] {
125 self.journal.get_entries()
126 }
127
128 pub fn journal(&self) -> &Journal {
130 &self.journal
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::format;
    use alloc::vec;
    use cynos_core::schema::TableBuilder;
    use cynos_core::{DataType, Value};

    /// Schema shared by all tests: table `test` with an `Int64` primary key `id`
    /// and a `String` column `name`.
    fn test_schema() -> cynos_core::schema::Table {
        TableBuilder::new("test")
            .unwrap()
            .add_column("id", DataType::Int64)
            .unwrap()
            .add_column("name", DataType::String)
            .unwrap()
            .add_primary_key(&["id"], false)
            .unwrap()
            .build()
            .unwrap()
    }

    /// Builds a cache that contains only the `test` table.
    fn test_cache() -> TableCache {
        let mut cache = TableCache::new();
        cache.create_table(test_schema()).unwrap();
        cache
    }

    #[test]
    fn test_transaction_begin() {
        let tx = Transaction::begin();
        assert!(tx.is_active());
        assert!(tx.id() > 0);
    }

    #[test]
    fn test_transaction_insert_commit() {
        let mut cache = test_cache();

        let mut tx = Transaction::begin();
        let row = Row::new(1, vec![Value::Int64(1), Value::String("test".into())]);
        tx.insert(&mut cache, "test", row).unwrap();

        let entries = tx.commit().unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(cache.get_table("test").unwrap().len(), 1);
    }

    #[test]
    fn test_transaction_rollback() {
        let mut cache = test_cache();

        let mut tx = Transaction::begin();
        let row = Row::new(1, vec![Value::Int64(1), Value::String("test".into())]);
        tx.insert(&mut cache, "test", row).unwrap();

        // The insert is visible in the table before the rollback.
        assert_eq!(cache.get_table("test").unwrap().len(), 1);

        tx.rollback(&mut cache).unwrap();
        assert_eq!(cache.get_table("test").unwrap().len(), 0);
    }

    #[test]
    fn test_transaction_update() {
        let mut cache = test_cache();

        // Seed the table outside of any transaction.
        let row = Row::new(1, vec![Value::Int64(1), Value::String("initial".into())]);
        cache.get_table_mut("test").unwrap().insert(row).unwrap();

        let mut tx = Transaction::begin();
        let new_row = Row::new(1, vec![Value::Int64(1), Value::String("updated".into())]);
        tx.update(&mut cache, "test", 1, new_row).unwrap();

        let entries = tx.commit().unwrap();
        assert_eq!(entries.len(), 1);

        let stored = cache.get_table("test").unwrap().get(1).unwrap();
        assert_eq!(stored.get(1), Some(&Value::String("updated".into())));
    }

    #[test]
    fn test_transaction_delete() {
        let mut cache = test_cache();

        // Seed the table outside of any transaction.
        let row = Row::new(1, vec![Value::Int64(1), Value::String("test".into())]);
        cache.get_table_mut("test").unwrap().insert(row).unwrap();

        let mut tx = Transaction::begin();
        tx.delete(&mut cache, "test", 1).unwrap();

        let entries = tx.commit().unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(cache.get_table("test").unwrap().len(), 0);
    }

    #[test]
    fn test_transaction_state_after_commit() {
        let tx = Transaction::begin();
        assert_eq!(tx.state(), TransactionState::Active);

        // `commit` takes the transaction by value, so the committed state cannot be
        // read back afterwards; what can be checked is that committing an active
        // transaction succeeds.
        assert!(tx.commit().is_ok());
    }

    #[test]
    fn test_multiple_operations() {
        let mut cache = test_cache();

        let mut tx = Transaction::begin();

        for i in 1..=3 {
            let row = Row::new(i, vec![Value::Int64(i as i64), Value::String(format!("row{}", i))]);
            tx.insert(&mut cache, "test", row).unwrap();
        }

        let updated = Row::new(2, vec![Value::Int64(2), Value::String("updated".into())]);
        tx.update(&mut cache, "test", 2, updated).unwrap();

        tx.delete(&mut cache, "test", 3).unwrap();

        let entries = tx.commit().unwrap();
        // 3 inserts + 1 update + 1 delete.
        assert_eq!(entries.len(), 5);
        // Rows 1 and 2 remain.
        assert_eq!(cache.get_table("test").unwrap().len(), 2);
    }

    #[test]
    fn test_transaction_rollback_update() {
        let mut cache = test_cache();

        // Seed the table outside of any transaction.
        let row = Row::new(1, vec![Value::Int64(1), Value::String("original".into())]);
        cache.get_table_mut("test").unwrap().insert(row).unwrap();

        let mut tx = Transaction::begin();
        let new_row = Row::new(1, vec![Value::Int64(1), Value::String("modified".into())]);
        tx.update(&mut cache, "test", 1, new_row).unwrap();

        // The update is visible before the rollback.
        assert_eq!(
            cache.get_table("test").unwrap().get(1).unwrap().get(1),
            Some(&Value::String("modified".into()))
        );

        tx.rollback(&mut cache).unwrap();

        // The original value is back after the rollback.
        assert_eq!(
            cache.get_table("test").unwrap().get(1).unwrap().get(1),
            Some(&Value::String("original".into()))
        );
    }

    #[test]
    fn test_transaction_rollback_delete() {
        let mut cache = test_cache();

        // Seed the table outside of any transaction.
        let row = Row::new(1, vec![Value::Int64(1), Value::String("test".into())]);
        cache.get_table_mut("test").unwrap().insert(row).unwrap();

        let mut tx = Transaction::begin();
        tx.delete(&mut cache, "test", 1).unwrap();

        // The row is gone before the rollback.
        assert!(cache.get_table("test").unwrap().get(1).is_none());

        tx.rollback(&mut cache).unwrap();

        // The row is back after the rollback.
        assert!(cache.get_table("test").unwrap().get(1).is_some());
    }

    #[test]
    fn test_transaction_complex_rollback() {
        let mut cache = test_cache();

        // Seed two rows outside of any transaction.
        let row1 = Row::new(1, vec![Value::Int64(1), Value::String("row1".into())]);
        let row2 = Row::new(2, vec![Value::Int64(2), Value::String("row2".into())]);
        cache.get_table_mut("test").unwrap().insert(row1).unwrap();
        cache.get_table_mut("test").unwrap().insert(row2).unwrap();

        let mut tx = Transaction::begin();

        let row3 = Row::new(3, vec![Value::Int64(3), Value::String("row3".into())]);
        tx.insert(&mut cache, "test", row3).unwrap();

        let updated_row1 = Row::new(1, vec![Value::Int64(1), Value::String("updated".into())]);
        tx.update(&mut cache, "test", 1, updated_row1).unwrap();

        tx.delete(&mut cache, "test", 2).unwrap();

        // Two seeded rows, plus row 3, minus row 2.
        assert_eq!(cache.get_table("test").unwrap().len(), 2);

        tx.rollback(&mut cache).unwrap();

        // Back to the two seeded rows with their original contents.
        assert_eq!(cache.get_table("test").unwrap().len(), 2);
        assert_eq!(
            cache.get_table("test").unwrap().get(1).unwrap().get(1),
            Some(&Value::String("row1".into()))
        );
        assert!(cache.get_table("test").unwrap().get(2).is_some());
        assert!(cache.get_table("test").unwrap().get(3).is_none());
    }

    #[test]
    fn test_transaction_error_on_inactive() {
        let mut cache = test_cache();

        // Seed a row so that update/delete would succeed on an active transaction;
        // any error below therefore comes from the activity check.
        let seed = Row::new(1, vec![Value::Int64(1), Value::String("seed".into())]);
        cache.get_table_mut("test").unwrap().insert(seed).unwrap();

        // `commit` and `rollback` consume the transaction, so a non-active
        // transaction cannot be obtained through the public API. The private
        // `state` field is reachable here because `tests` is a child module.
        let mut tx = Transaction::begin();
        tx.state = TransactionState::Committed;
        assert!(!tx.is_active());

        let new_row = Row::new(2, vec![Value::Int64(2), Value::String("new".into())]);
        assert!(tx.insert(&mut cache, "test", new_row).is_err());

        let changed = Row::new(1, vec![Value::Int64(1), Value::String("changed".into())]);
        assert!(tx.update(&mut cache, "test", 1, changed).is_err());

        assert!(tx.delete(&mut cache, "test", 1).is_err());

        // None of the rejected operations touched the table.
        assert_eq!(cache.get_table("test").unwrap().len(), 1);
        assert_eq!(
            cache.get_table("test").unwrap().get(1).unwrap().get(1),
            Some(&Value::String("seed".into()))
        );

        assert!(tx.commit().is_err());

        let mut tx = Transaction::begin();
        tx.state = TransactionState::RolledBack;
        assert!(tx.rollback(&mut cache).is_err());
    }

    #[test]
    fn test_transaction_journal_entries() {
        let mut cache = test_cache();

        let mut tx = Transaction::begin();

        let row = Row::new(1, vec![Value::Int64(1), Value::String("test".into())]);
        tx.insert(&mut cache, "test", row).unwrap();

        // The pending insert is observable before the transaction is committed.
        let changes = tx.get_changes();
        assert_eq!(changes.len(), 1);
        assert!(matches!(changes[0], JournalEntry::Insert { .. }));

        tx.commit().unwrap();
    }
}