1use crate::backup::BackupManager;
4use crate::error::FileError;
5use crate::models::{FileOperation, FileTransaction, TransactionStatus};
6use crate::writer::SafeWriter;
7use chrono::Utc;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::fs;
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
/// Coordinates multi-file write transactions: operations are queued per
/// transaction, written together on commit, and can be undone from the backups
/// taken just before the commit wrote anything.
///
/// Clones share one transaction table (it sits behind an `Arc`); the writer and
/// backup manager are cloned through their own `Clone` impls.
#[derive(Debug, Clone)]
pub struct TransactionManager {
    /// Every transaction known to this manager (and its clones), keyed by id.
    transactions: Arc<RwLock<HashMap<Uuid, FileTransaction>>>,
    /// Performs the actual file writes during commit.
    writer: SafeWriter,
    /// Creates pre-commit backups and restores them on rollback.
    backup_manager: BackupManager,
}
21
22impl TransactionManager {
23 pub fn new(backup_manager: BackupManager) -> Self {
29 TransactionManager {
30 transactions: Arc::new(RwLock::new(HashMap::new())),
31 writer: SafeWriter::new(),
32 backup_manager,
33 }
34 }
35
36 pub async fn begin_transaction(&self) -> Result<Uuid, FileError> {
42 let tx_id = Uuid::new_v4();
43 let transaction = FileTransaction {
44 id: tx_id,
45 operations: Vec::new(),
46 status: TransactionStatus::Pending,
47 created_at: Utc::now(),
48 completed_at: None,
49 };
50
51 let mut transactions = self.transactions.write().await;
52 transactions.insert(tx_id, transaction);
53
54 Ok(tx_id)
55 }
56
57 pub async fn add_operation(&self, tx_id: Uuid, op: FileOperation) -> Result<(), FileError> {
68 let mut transactions = self.transactions.write().await;
69
70 let transaction = transactions
71 .get_mut(&tx_id)
72 .ok_or_else(|| FileError::TransactionFailed("Transaction not found".to_string()))?;
73
74 if transaction.status != TransactionStatus::Pending {
75 return Err(FileError::TransactionFailed(
76 "Cannot add operations to non-pending transaction".to_string(),
77 ));
78 }
79
80 transaction.operations.push(op);
81 Ok(())
82 }
83
84 pub async fn commit(&self, tx_id: Uuid) -> Result<(), FileError> {
94 let mut transactions = self.transactions.write().await;
95
96 let transaction = transactions
97 .get_mut(&tx_id)
98 .ok_or_else(|| FileError::TransactionFailed("Transaction not found".to_string()))?;
99
100 if transaction.status != TransactionStatus::Pending {
101 return Err(FileError::TransactionFailed(
102 "Transaction is not pending".to_string(),
103 ));
104 }
105
106 let mut pre_transaction_backups: HashMap<std::path::PathBuf, Option<std::path::PathBuf>> =
108 HashMap::new();
109
110 for op in &transaction.operations {
111 if op.path.exists() {
112 let backup_metadata = self.backup_manager.create_backup(&op.path).await?;
113 pre_transaction_backups.insert(op.path.clone(), Some(backup_metadata.backup_path));
114 } else {
115 pre_transaction_backups.insert(op.path.clone(), None);
117 }
118 }
119
120 transaction.operations.iter_mut().for_each(|op| {
122 if let Some(backup_opt) = pre_transaction_backups.get(&op.path) {
123 op.backup_path = backup_opt.clone();
124 }
125 });
126
127 let mut executed_count = 0;
129 for op in &transaction.operations {
130 match self
131 .writer
132 .write(
133 &op.path,
134 op.content.as_ref().unwrap_or(&String::new()),
135 crate::models::ConflictResolution::Overwrite,
136 )
137 .await
138 {
139 Ok(_) => {
140 executed_count += 1;
141 }
142 Err(e) => {
143 self.rollback_operations(
145 &pre_transaction_backups,
146 executed_count,
147 &transaction.operations,
148 )
149 .await?;
150
151 return Err(FileError::TransactionFailed(format!(
152 "Operation failed after {} successful operations: {}",
153 executed_count, e
154 )));
155 }
156 }
157 }
158
159 transaction.status = TransactionStatus::Committed;
161 transaction.completed_at = Some(Utc::now());
162
163 for op in &transaction.operations {
165 let _ = self.backup_manager.enforce_retention_policy(&op.path).await;
166 }
167
168 Ok(())
169 }
170
171 pub async fn rollback(&self, tx_id: Uuid) -> Result<(), FileError> {
181 let mut transactions = self.transactions.write().await;
182
183 let transaction = transactions
184 .get_mut(&tx_id)
185 .ok_or_else(|| FileError::TransactionFailed("Transaction not found".to_string()))?;
186
187 if transaction.status == TransactionStatus::RolledBack {
188 return Err(FileError::TransactionFailed(
189 "Transaction already rolled back".to_string(),
190 ));
191 }
192
193 for op in &transaction.operations {
195 if let Some(backup_path_opt) = &op.backup_path {
196 self.backup_manager
198 .restore_from_backup(backup_path_opt, &op.path)
199 .await?;
200 } else if op.path.exists() {
201 fs::remove_file(&op.path).await.map_err(|e| {
203 FileError::RollbackFailed(format!(
204 "Failed to delete file during rollback: {}",
205 e
206 ))
207 })?;
208 }
209 }
210
211 transaction.status = TransactionStatus::RolledBack;
213 transaction.completed_at = Some(Utc::now());
214
215 Ok(())
216 }
217
218 pub async fn get_status(&self, tx_id: Uuid) -> Result<TransactionStatus, FileError> {
228 let transactions = self.transactions.read().await;
229
230 transactions
231 .get(&tx_id)
232 .map(|t| t.status)
233 .ok_or_else(|| FileError::TransactionFailed("Transaction not found".to_string()))
234 }
235
236 pub async fn get_transaction(&self, tx_id: Uuid) -> Result<FileTransaction, FileError> {
246 let transactions = self.transactions.read().await;
247
248 transactions
249 .get(&tx_id)
250 .cloned()
251 .ok_or_else(|| FileError::TransactionFailed("Transaction not found".to_string()))
252 }
253
254 async fn rollback_operations(
256 &self,
257 backups: &HashMap<std::path::PathBuf, Option<std::path::PathBuf>>,
258 executed_count: usize,
259 operations: &[FileOperation],
260 ) -> Result<(), FileError> {
261 for op in operations.iter().take(executed_count) {
263 if let Some(backup_opt) = backups.get(&op.path) {
264 if let Some(backup_path) = backup_opt {
265 self.backup_manager
266 .restore_from_backup(backup_path, &op.path)
267 .await
268 .map_err(|e| {
269 FileError::RollbackFailed(format!(
270 "Failed to restore backup during rollback: {}",
271 e
272 ))
273 })?;
274 } else if op.path.exists() {
275 fs::remove_file(&op.path).await.map_err(|e| {
277 FileError::RollbackFailed(format!(
278 "Failed to delete file during rollback: {}",
279 e
280 ))
281 })?;
282 }
283 }
284 }
285
286 Ok(())
287 }
288}
289
290impl Default for TransactionManager {
291 fn default() -> Self {
292 Self::new(BackupManager::default())
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::OperationType;
    use std::path::PathBuf;
    use tempfile::TempDir;

    /// Manager whose backups live under `<temp_dir>/backups` (retention argument 10).
    fn manager_in(temp_dir: &TempDir) -> TransactionManager {
        TransactionManager::new(BackupManager::new(temp_dir.path().join("backups"), 10))
    }

    /// Builds a queued operation with the given content and hash and no backup yet.
    fn file_op(path: PathBuf, operation: OperationType, content: &str, hash: &str) -> FileOperation {
        FileOperation {
            path,
            operation,
            content: Some(content.to_string()),
            backup_path: None,
            content_hash: Some(hash.to_string()),
        }
    }

    #[tokio::test]
    async fn test_begin_transaction() {
        let manager = TransactionManager::default();
        let tx_id = manager.begin_transaction().await.unwrap();

        let status = manager.get_status(tx_id).await.unwrap();
        assert_eq!(status, TransactionStatus::Pending);
    }

    #[tokio::test]
    async fn test_add_operation() {
        let manager = TransactionManager::default();
        let tx_id = manager.begin_transaction().await.unwrap();

        let op = file_op(PathBuf::from("test.txt"), OperationType::Create, "test content", "hash");
        assert!(manager.add_operation(tx_id, op).await.is_ok());

        let transaction = manager.get_transaction(tx_id).await.unwrap();
        assert_eq!(transaction.operations.len(), 1);
    }

    #[tokio::test]
    async fn test_add_operation_to_non_pending_transaction() {
        let temp_dir = TempDir::new().unwrap();
        let manager = manager_in(&temp_dir);
        let tx_id = manager.begin_transaction().await.unwrap();

        {
            // Force a non-pending state directly; no file I/O is needed for this check.
            let mut transactions = manager.transactions.write().await;
            if let Some(tx) = transactions.get_mut(&tx_id) {
                tx.status = TransactionStatus::Committed;
            }
        }

        let op = file_op(PathBuf::from("test.txt"), OperationType::Create, "test content", "hash");
        assert!(manager.add_operation(tx_id, op).await.is_err());
    }

    #[tokio::test]
    async fn test_commit_single_operation() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("test.txt");
        let manager = manager_in(&temp_dir);
        let tx_id = manager.begin_transaction().await.unwrap();

        let op = file_op(file_path.clone(), OperationType::Create, "test content", "hash");
        manager.add_operation(tx_id, op).await.unwrap();

        assert!(manager.commit(tx_id).await.is_ok());
        assert_eq!(manager.get_status(tx_id).await.unwrap(), TransactionStatus::Committed);
        assert_eq!(fs::read_to_string(&file_path).await.unwrap(), "test content");
    }

    #[tokio::test]
    async fn test_commit_multiple_operations() {
        let temp_dir = TempDir::new().unwrap();
        let file1 = temp_dir.path().join("file1.txt");
        let file2 = temp_dir.path().join("file2.txt");
        let manager = manager_in(&temp_dir);
        let tx_id = manager.begin_transaction().await.unwrap();

        let op1 = file_op(file1.clone(), OperationType::Create, "content 1", "hash1");
        let op2 = file_op(file2.clone(), OperationType::Create, "content 2", "hash2");
        manager.add_operation(tx_id, op1).await.unwrap();
        manager.add_operation(tx_id, op2).await.unwrap();

        assert!(manager.commit(tx_id).await.is_ok());
        assert_eq!(fs::read_to_string(&file1).await.unwrap(), "content 1");
        assert_eq!(fs::read_to_string(&file2).await.unwrap(), "content 2");
    }

    #[tokio::test]
    async fn test_commit_twice_fails() {
        let temp_dir = TempDir::new().unwrap();
        let manager = manager_in(&temp_dir);
        let tx_id = manager.begin_transaction().await.unwrap();

        let op = file_op(temp_dir.path().join("once.txt"), OperationType::Create, "once", "hash");
        manager.add_operation(tx_id, op).await.unwrap();
        manager.commit(tx_id).await.unwrap();

        assert!(manager.commit(tx_id).await.is_err());
        assert_eq!(manager.get_status(tx_id).await.unwrap(), TransactionStatus::Committed);
    }

    #[tokio::test]
    async fn test_rollback_restores_files() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("test.txt");
        fs::write(&file_path, "original content").await.unwrap();

        let manager = manager_in(&temp_dir);
        let tx_id = manager.begin_transaction().await.unwrap();

        let op = file_op(file_path.clone(), OperationType::Update, "new content", "hash");
        manager.add_operation(tx_id, op).await.unwrap();
        manager.commit(tx_id).await.unwrap();
        assert_eq!(fs::read_to_string(&file_path).await.unwrap(), "new content");

        assert!(manager.rollback(tx_id).await.is_ok());
        assert_eq!(fs::read_to_string(&file_path).await.unwrap(), "original content");
    }

    #[tokio::test]
    async fn test_rollback_deletes_created_files() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("new_file.txt");
        let manager = manager_in(&temp_dir);
        let tx_id = manager.begin_transaction().await.unwrap();

        let op = file_op(file_path.clone(), OperationType::Create, "new content", "hash");
        manager.add_operation(tx_id, op).await.unwrap();
        manager.commit(tx_id).await.unwrap();
        assert!(file_path.exists());

        assert!(manager.rollback(tx_id).await.is_ok());
        assert!(!file_path.exists());
    }

    #[tokio::test]
    async fn test_rollback_twice_fails() {
        let temp_dir = TempDir::new().unwrap();
        let manager = manager_in(&temp_dir);
        let tx_id = manager.begin_transaction().await.unwrap();

        let op = file_op(temp_dir.path().join("twice.txt"), OperationType::Create, "x", "hash");
        manager.add_operation(tx_id, op).await.unwrap();
        manager.commit(tx_id).await.unwrap();
        manager.rollback(tx_id).await.unwrap();

        assert!(manager.rollback(tx_id).await.is_err());
        assert_eq!(manager.get_status(tx_id).await.unwrap(), TransactionStatus::RolledBack);
    }

    #[tokio::test]
    async fn test_get_transaction() {
        let manager = TransactionManager::default();
        let tx_id = manager.begin_transaction().await.unwrap();

        let transaction = manager.get_transaction(tx_id).await.unwrap();
        assert_eq!(transaction.id, tx_id);
        assert_eq!(transaction.status, TransactionStatus::Pending);
        assert_eq!(transaction.operations.len(), 0);
    }

    #[tokio::test]
    async fn test_get_nonexistent_transaction() {
        let manager = TransactionManager::default();
        let fake_id = Uuid::new_v4();

        assert!(manager.get_transaction(fake_id).await.is_err());
        assert!(manager.get_status(fake_id).await.is_err());
    }

    #[tokio::test]
    async fn test_commit_and_rollback_unknown_transaction() {
        let manager = TransactionManager::default();
        let fake_id = Uuid::new_v4();

        assert!(manager.commit(fake_id).await.is_err());
        assert!(manager.rollback(fake_id).await.is_err());
    }
}