mixtape_tools/sqlite/transaction/
rollback.rs1use crate::prelude::*;
4use crate::sqlite::manager::with_connection;
5
6#[derive(Debug, Deserialize, JsonSchema)]
8pub struct RollbackTransactionInput {
9 #[serde(default)]
11 pub db_path: Option<String>,
12}
13
14pub struct RollbackTransactionTool;
19
20impl Tool for RollbackTransactionTool {
21 type Input = RollbackTransactionInput;
22
23 fn name(&self) -> &str {
24 "sqlite_rollback_transaction"
25 }
26
27 fn description(&self) -> &str {
28 "Rollback the current transaction, reverting all changes made since the transaction began."
29 }
30
31 async fn execute(&self, input: Self::Input) -> Result<ToolResult, ToolError> {
32 with_connection(input.db_path, |conn| {
33 conn.execute("ROLLBACK", [])?;
34 Ok(())
35 })
36 .await?;
37
38 let response = serde_json::json!({
39 "status": "success",
40 "message": "Transaction rolled back successfully"
41 });
42 Ok(ToolResult::Json(response))
43 }
44}
45
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::test_utils::TestDatabase;
    use crate::sqlite::transaction::begin::{BeginTransactionInput, TransactionType};
    use crate::sqlite::transaction::BeginTransactionTool;

    /// Opens a deferred transaction on `db` through the public begin tool.
    async fn begin_deferred(db: &TestDatabase) {
        BeginTransactionTool
            .execute(BeginTransactionInput {
                db_path: Some(db.key()),
                transaction_type: TransactionType::Deferred,
            })
            .await
            .expect("BEGIN should succeed on a fresh test database");
    }

    #[tokio::test]
    async fn test_rollback_transaction() {
        let db = TestDatabase::with_schema("CREATE TABLE test (id INTEGER);").await;

        begin_deferred(&db).await;
        db.execute("INSERT INTO test VALUES (1)");
        db.execute("INSERT INTO test VALUES (2)");
        // Guard against a vacuous pass: the rows must be visible inside the transaction
        // before the rollback is expected to remove them.
        assert_eq!(db.count("test"), 2);

        let result = RollbackTransactionTool
            .execute(RollbackTransactionInput {
                db_path: Some(db.key()),
            })
            .await
            .expect("ROLLBACK of an open transaction should succeed");

        let ToolResult::Json(response) = result else {
            panic!("rollback should return a JSON result");
        };
        assert_eq!(response["status"], "success");
        assert_eq!(response["message"], "Transaction rolled back successfully");

        assert_eq!(db.count("test"), 0);
    }

    #[tokio::test]
    async fn test_rollback_without_active_transaction_fails() {
        let db = TestDatabase::with_schema("CREATE TABLE test (id INTEGER);").await;

        // SQLite rejects ROLLBACK when no transaction is open; the tool must surface that
        // as an error instead of reporting success.
        let result = RollbackTransactionTool
            .execute(RollbackTransactionInput {
                db_path: Some(db.key()),
            })
            .await;

        assert!(
            result.is_err(),
            "ROLLBACK outside a transaction should be reported as an error"
        );
    }

    #[test]
    fn test_tool_metadata() {
        let tool = RollbackTransactionTool;
        assert_eq!(tool.name(), "sqlite_rollback_transaction");
        assert!(!tool.description().is_empty());
    }
}