1use super::{Runner, RunnerError};
2use crate::logger::{LogMessage, LoggerOutput};
3
4impl<O: LoggerOutput> Runner<O> {
5 pub fn rollback(&mut self) -> Result<(), RunnerError> {
6 let Some(actions) = self.last_action() else {
7 return Err(RunnerError::NoActionToRollback);
8 };
9
10 let (head, actions) = actions.split_first().unwrap();
11 match head {
12 LogMessage::LoadModule(module) => self.logger.rollback_load_module(module),
13 LogMessage::UnloadModule(module) => self.logger.rollback_unload_module(module),
14 _ => unreachable!(),
15 }
16
17 for action in actions.iter().rev() {
18 match action {
19 LogMessage::LoadModule(_)
20 | LogMessage::UnloadModule(_)
21 | LogMessage::RollbackLoadModule(_)
22 | LogMessage::RollbackUnloadModule(_) => unreachable!(),
23
24 LogMessage::CreateDir(path) => self.remove_dir(path)?,
25 LogMessage::CreateSymlink { src, dst } => self.remove_symlink(src, dst)?,
26
27 LogMessage::RemoveDir(path) => self.create_dir(path)?,
28 LogMessage::RemoveSymlink { src, dst } => self.create_symlink(src, dst)?,
29 }
30 }
31
32 Ok(())
33 }
34
35 fn last_action(&self) -> Option<Vec<LogMessage>> {
36 let msgs = self.messages();
37 for i in (0..msgs.len()).rev() {
38 match &msgs[i] {
39 LogMessage::LoadModule(_) | LogMessage::UnloadModule(_) => {
40 return Some(msgs[i..].to_vec());
41 }
42 LogMessage::RollbackLoadModule(_) | LogMessage::RollbackUnloadModule(_) => {
43 return None;
44 }
45 _ => {}
46 }
47 }
48 None
49 }
50}
51
#[cfg(test)]
mod tests {
    use std::fs;

    use crate::test_utils::prelude::*;

    // With nothing logged there is no action for `rollback` to undo.
    #[gtest]
    fn nothing_to_rollback() -> Result<()> {
        let (_td, _pkg, mut runner) = common_local_pkg()?;
        let err = runner.rollback().unwrap_err();
        expect_that!(err, pat!(RunnerError::NoActionToRollback));
        Ok(())
    }

    // The first rollback records a rollback marker after the load. A second rollback
    // then finds that marker and reports that there is nothing left to undo.
    #[gtest]
    fn rollback_twice() -> Result<()> {
        let (_td, pkg, mut runner) = common_local_pkg()?;
        runner.load_module(&pkg, None)?;
        runner.rollback()?;

        let err = runner.rollback().unwrap_err();
        expect_that!(err, pat!(RunnerError::NoActionToRollback));

        Ok(())
    }

    // Rolling back a `load_module` call.
    mod rollback_load_module {
        use super::*;

        // After a successful load, rollback removes both destination links. It logs the
        // marker followed by the inverse of every load step, newest step first.
        #[gtest]
        fn after_success() -> Result<()> {
            let (td, pkg, mut runner) = common_local_pkg()?;

            runner.load_module(&pkg, None)?;
            // Everything logged from this index on is produced by `rollback`.
            let rollback_begin = runner.messages().len();

            runner.rollback()?;

            expect_pred!(!td.join(DST_DIR_PATH).exists());
            expect_pred!(!td.join(DST_FILE_PATH).exists());

            expect_eq!(
                runner.messages()[rollback_begin..],
                [
                    LogMessage::RollbackLoadModule("test_package".into()),
                    LogMessage::RemoveSymlink {
                        src: td.join(SRC_DIR_PATH).canonicalize()?,
                        dst: td.join(DST_DIR_PATH)
                    },
                    LogMessage::RemoveDir(td.join("./test_a/test_b")),
                    LogMessage::RemoveSymlink {
                        src: td.join(SRC_FILE_PATH).canonicalize()?,
                        dst: td.join(DST_FILE_PATH)
                    },
                    LogMessage::RemoveDir(td.join("./test_pkg")),
                ]
            );

            Ok(())
        }

        // Deleting the source directory makes the load fail part-way. Rollback only
        // undoes the steps that were logged before the failure. Per the expected log,
        // those are the file link and the `test_pkg` directory; no directory-link entry
        // appears.
        #[gtest]
        fn after_failure() -> Result<()> {
            let (td, pkg, mut runner) = common_local_pkg()?;
            fs::remove_dir(td.join(SRC_DIR_PATH))?;

            let _ = runner.load_module(&pkg, None).unwrap_err();
            // Everything logged from this index on is produced by `rollback`.
            let rollback_begin = runner.messages().len();

            runner.rollback()?;

            expect_pred!(!td.join(DST_DIR_PATH).exists());
            expect_pred!(!td.join(DST_FILE_PATH).exists());

            expect_eq!(
                runner.messages()[rollback_begin..].to_vec(),
                [
                    LogMessage::RollbackLoadModule("test_package".into()),
                    LogMessage::RemoveSymlink {
                        src: td.join(SRC_FILE_PATH).canonicalize()?,
                        dst: td.join(DST_FILE_PATH)
                    },
                    LogMessage::RemoveDir(td.join("./test_pkg")),
                ]
            );

            Ok(())
        }

        // A second load adds one extra mapping on top of the first load's trace. Rolling
        // it back removes only the new link; the links from the first load survive.
        #[gtest]
        fn only_rollback_last_loading() -> Result<()> {
            let (td, mut pkg, mut runner) = common_local_pkg()?;
            let trace = runner.load_module(&pkg, None)?;

            // Add a new source file and map it to a new destination.
            let new_src_file = "test_package/new_src_file";
            let td = td.file(new_src_file, "")?;
            pkg.insert_map("new_src_file", td.join("new_dst_file").to_string_lossy());

            // A fresh runner (with its own log) performs the second load.
            let mut runner = common_runner(td.path());
            runner.load_module(&pkg, Some(&trace))?;

            // Everything logged from this index on is produced by `rollback`.
            let rollback_begin = runner.messages().len();

            runner.rollback()?;

            expect_pred!(td.join(DST_DIR_PATH).exists());
            expect_pred!(td.join(DST_FILE_PATH).exists());
            expect_pred!(!td.join("new_dst_file").exists());

            expect_eq!(
                runner.messages()[rollback_begin..].to_vec(),
                [
                    LogMessage::RollbackLoadModule("test_package".into()),
                    LogMessage::RemoveSymlink {
                        src: td.join(new_src_file).canonicalize()?,
                        dst: td.join("new_dst_file")
                    }
                ]
            );

            Ok(())
        }
    }

    // Rolling back an `unload_module` call.
    mod rollback_unload_module {
        use super::*;

        // After a successful unload, rollback re-creates both destination links. It logs
        // the marker plus one inverse entry per logged unload step.
        #[gtest]
        fn after_success() -> Result<()> {
            let (td, pkg, mut runner) = common_local_pkg()?;
            let trace = runner.load_module(&pkg, None)?;

            // Unload with a fresh runner so its log covers only the unload.
            let mut runner = common_runner(td.path());
            runner.unload_module("test_package", &trace)?;

            // The unload steps: everything after the leading (`UnloadModule`) entry.
            let msgs = runner.messages()[1..].to_vec();
            let rollback_begin = runner.messages().len();

            runner.rollback()?;
            let rollback_msgs = runner.messages()[rollback_begin..].to_vec();

            expect_that!(
                rollback_msgs[0],
                pat!(LogMessage::RollbackUnloadModule("test_package"))
            );
            // One marker plus one inverse entry per unload step.
            expect_eq!(rollback_msgs.len(), msgs.len() + 1);

            expect_pred!(td.join(DST_DIR_PATH).exists());
            expect_pred!(td.join(DST_FILE_PATH).exists());

            expect_that!(
                rollback_msgs,
                contains(pat!(LogMessage::CreateSymlink {
                    src: &td.join(SRC_DIR_PATH),
                    dst: &td.join(DST_DIR_PATH),
                }))
            );
            expect_that!(
                rollback_msgs,
                contains(pat!(LogMessage::CreateSymlink {
                    src: &td.join(SRC_FILE_PATH),
                    dst: &td.join(DST_FILE_PATH),
                }))
            );

            Ok(())
        }

        // Deleting the file link beforehand makes the unload fail. Rollback still logs
        // one inverse entry per logged unload step, and the directory link exists
        // afterwards.
        #[gtest]
        fn after_failure() -> Result<()> {
            let (td, pkg, mut runner) = common_local_pkg()?;
            let trace = runner.load_module(&pkg, None)?;
            fs::remove_file(td.join(DST_FILE_PATH))?;

            let mut runner = common_runner(td.path());
            let _ = runner.unload_module("test_package", &trace).unwrap_err();

            // The unload steps: everything after the leading (`UnloadModule`) entry.
            let msgs = runner.messages()[1..].to_vec();
            let rollback_begin = runner.messages().len();
            // Whether the unload got as far as removing the directory link presumably
            // depends on the order in which it visits entries. The re-creation check
            // below is therefore conditional on that removal having been logged.
            let unload_src_dir = runner.messages().contains(&LogMessage::RemoveSymlink {
                src: td.join(SRC_DIR_PATH),
                dst: td.join(DST_DIR_PATH),
            });

            runner.rollback()?;
            let rollback_msgs = runner.messages()[rollback_begin..].to_vec();

            expect_that!(
                rollback_msgs[0],
                pat!(LogMessage::RollbackUnloadModule("test_package"))
            );
            // One marker plus one inverse entry per logged unload step.
            expect_eq!(rollback_msgs.len(), msgs.len() + 1);

            expect_pred!(td.join(DST_DIR_PATH).exists());

            if unload_src_dir {
                expect_that!(
                    rollback_msgs,
                    contains(pat!(LogMessage::CreateSymlink {
                        src: &td.join(SRC_DIR_PATH),
                        dst: &td.join(DST_DIR_PATH),
                    }))
                );
            }

            Ok(())
        }
    }
}