1use super::{Runner, RunnerError};
2use crate::logger::{LogMessage, LoggerOutput};
3
4impl<O: LoggerOutput> Runner<O> {
5 pub fn rollback(&mut self) -> Result<(), RunnerError> {
6 let Some(actions) = self.last_action() else {
7 return Err(RunnerError::NoActionToRollback);
8 };
9
10 let (head, actions) = actions.split_first().unwrap();
11 match head {
12 LogMessage::LoadModule(module) => self.logger.rollback_load_module(module),
13 LogMessage::UnloadModule(module) => self.logger.rollback_unload_module(module),
14 _ => unreachable!(),
15 }
16
17 for action in actions.iter().rev() {
18 match action {
19 LogMessage::LoadModule(_)
20 | LogMessage::UnloadModule(_)
21 | LogMessage::RollbackLoadModule(_)
22 | LogMessage::RollbackUnloadModule(_) => unreachable!(),
23
24 LogMessage::CreateDir(path) => self.remove_dir(path)?,
25 LogMessage::CreateSymlink { src, dst } => self.remove_symlink(src, dst)?,
26
27 LogMessage::RemoveDir(path) => self.create_dir(path)?,
28 LogMessage::RemoveSymlink { src, dst } => self.create_symlink(src, dst)?,
29 }
30 }
31
32 Ok(())
33 }
34
35 fn last_action(&self) -> Option<Vec<LogMessage>> {
36 let msgs = self.messages();
37 for i in (0..msgs.len()).rev() {
38 match &msgs[i] {
39 LogMessage::LoadModule(_) | LogMessage::UnloadModule(_) => {
40 return Some(msgs[i..].to_vec());
41 }
42 LogMessage::RollbackLoadModule(_) | LogMessage::RollbackUnloadModule(_) => {
43 return None;
44 }
45 _ => {}
46 }
47 }
48 None
49 }
50}
51
#[cfg(test)]
mod tests {
    use std::fs;

    use crate::test_utils::prelude::*;

    /// Rolling back before any load/unload reports `NoActionToRollback`.
    #[gtest]
    fn nothing_to_rollback() -> Result<()> {
        let (_td, _pkg, mut runner) = common_local_pkg()?;

        let failure = runner.rollback().unwrap_err();
        expect_that!(failure, pat!(RunnerError::NoActionToRollback));

        Ok(())
    }

    /// A second rollback of the same load reports `NoActionToRollback`.
    #[gtest]
    fn rollback_twice() -> Result<()> {
        let (_td, pkg, mut runner) = common_local_pkg()?;

        runner.load_module(&pkg, None)?;
        runner.rollback()?;

        let failure = runner.rollback().unwrap_err();
        expect_that!(failure, pat!(RunnerError::NoActionToRollback));

        Ok(())
    }

    mod rollback_load_module {
        use super::*;

        /// Rolling back a successful load removes both destinations and logs
        /// one entry per loaded step plus the rollback entry.
        #[gtest]
        fn after_success() -> Result<()> {
            let (td, pkg, mut runner) = common_local_pkg()?;
            runner.load_module(&pkg, None)?;

            // Snapshot the log size before the rollback appends to it.
            let applied_steps = runner.messages()[1..].len();
            let mark = runner.messages().len();

            runner.rollback()?;
            let undo_log = runner.messages()[mark..].to_vec();

            expect_that!(
                undo_log[0],
                pat!(LogMessage::RollbackLoadModule("test_package"))
            );
            expect_eq!(undo_log.len(), applied_steps + 1);

            expect_pred!(!td.join(DST_DIR_PATH).exists());
            expect_pred!(!td.join(DST_FILE_PATH).exists());

            expect_that!(
                undo_log,
                superset_of([
                    &LogMessage::RemoveSymlink {
                        src: td.join(SRC_FILE_PATH).canonicalize()?,
                        dst: td.join(DST_FILE_PATH)
                    },
                    &LogMessage::RemoveDir(td.join("./test_pkg")),
                ])
            );
            expect_that!(
                undo_log,
                superset_of([
                    &LogMessage::RemoveSymlink {
                        src: td.join(SRC_DIR_PATH).canonicalize()?,
                        dst: td.join(DST_DIR_PATH)
                    },
                    &LogMessage::RemoveDir(td.join("./test_a/test_b")),
                ])
            );

            Ok(())
        }

        /// Rolling back a failed load still reverts whatever was logged
        /// before the failure.
        #[gtest]
        fn after_failure() -> Result<()> {
            let (td, pkg, mut runner) = common_local_pkg()?;

            // Deleting the source file makes `load_module` fail.
            fs::remove_file(td.join(SRC_FILE_PATH))?;
            let _ = runner.load_module(&pkg, None).unwrap_err();

            let applied_steps = runner.messages()[1..].len();
            let mark = runner.messages().len();
            // Whether the directory destination existed before the rollback.
            let dir_was_linked = td.join(DST_DIR_PATH).exists();

            runner.rollback()?;
            let undo_log = runner.messages()[mark..].to_vec();

            expect_that!(
                undo_log[0],
                pat!(LogMessage::RollbackLoadModule("test_package"))
            );
            expect_eq!(undo_log.len(), applied_steps + 1);

            expect_pred!(!td.join(DST_DIR_PATH).exists());
            expect_pred!(!td.join(DST_FILE_PATH).exists());

            if dir_was_linked {
                expect_that!(
                    undo_log,
                    superset_of([
                        &LogMessage::RemoveSymlink {
                            src: td.join(SRC_DIR_PATH).canonicalize()?,
                            dst: td.join(DST_DIR_PATH)
                        },
                        &LogMessage::RemoveDir(td.join("./test_a/test_b")),
                    ])
                );
            }

            Ok(())
        }

        /// After a second load made with the first load's trace, rollback
        /// removes only the newly mapped file; the earlier destinations stay.
        #[gtest]
        fn only_rollback_last_loading() -> Result<()> {
            let (td, mut pkg, mut runner) = common_local_pkg()?;
            let first_trace = runner.load_module(&pkg, None)?;

            // Add one more source file and map it to a new destination.
            let extra_src = "test_package/new_src_file";
            let extra_dst = "new_dst_file";
            let td = td.file(extra_src, "")?;
            pkg.insert_map("new_src_file", td.join(extra_dst).to_string_lossy());

            let mut runner = common_runner(td.path());
            runner.load_module(&pkg, Some(&first_trace))?;

            let applied_steps = runner.messages()[1..].len();
            let mark = runner.messages().len();

            runner.rollback()?;
            let undo_log = runner.messages()[mark..].to_vec();

            expect_that!(
                undo_log[0],
                pat!(LogMessage::RollbackLoadModule("test_package"))
            );
            expect_eq!(undo_log.len(), applied_steps + 1);

            expect_pred!(td.join(DST_DIR_PATH).exists());
            expect_pred!(td.join(DST_FILE_PATH).exists());
            expect_pred!(!td.join(extra_dst).exists());

            expect_that!(
                undo_log,
                superset_of([&LogMessage::RemoveSymlink {
                    src: td.join(extra_src).canonicalize()?,
                    dst: td.join(extra_dst)
                },])
            );

            Ok(())
        }
    }

    mod rollback_unload_module {
        use super::*;

        /// Rolling back a successful unload brings both destinations back and
        /// logs a `CreateSymlink` entry for each of them.
        #[gtest]
        fn after_success() -> Result<()> {
            let (td, pkg, mut runner) = common_local_pkg()?;
            let loaded_trace = runner.load_module(&pkg, None)?;

            // The unload runs on a separate runner created for the same directory.
            let mut runner = common_runner(td.path());
            runner.unload_module("test_package", &loaded_trace)?;

            let applied_steps = runner.messages()[1..].len();
            let mark = runner.messages().len();

            runner.rollback()?;
            let undo_log = runner.messages()[mark..].to_vec();

            expect_that!(
                undo_log[0],
                pat!(LogMessage::RollbackUnloadModule("test_package"))
            );
            expect_eq!(undo_log.len(), applied_steps + 1);

            expect_pred!(td.join(DST_DIR_PATH).exists());
            expect_pred!(td.join(DST_FILE_PATH).exists());

            expect_that!(
                undo_log,
                contains(pat!(LogMessage::CreateSymlink {
                    src: &td.join(SRC_DIR_PATH),
                    dst: &td.join(DST_DIR_PATH),
                }))
            );
            expect_that!(
                undo_log,
                contains(pat!(LogMessage::CreateSymlink {
                    src: &td.join(SRC_FILE_PATH),
                    dst: &td.join(DST_FILE_PATH),
                }))
            );

            Ok(())
        }

        /// Rolling back a failed unload restores whatever had been removed
        /// before the failure.
        #[gtest]
        fn after_failure() -> Result<()> {
            let (td, pkg, mut runner) = common_local_pkg()?;
            let loaded_trace = runner.load_module(&pkg, None)?;

            // Deleting the destination file makes `unload_module` fail.
            fs::remove_file(td.join(DST_FILE_PATH))?;

            let mut runner = common_runner(td.path());
            let _ = runner
                .unload_module("test_package", &loaded_trace)
                .unwrap_err();

            let applied_steps = runner.messages()[1..].len();
            let mark = runner.messages().len();
            // Whether the directory symlink removal was logged before the failure.
            let dir_unlink = LogMessage::RemoveSymlink {
                src: td.join(SRC_DIR_PATH),
                dst: td.join(DST_DIR_PATH),
            };
            let dir_was_unlinked = runner.messages().contains(&dir_unlink);

            runner.rollback()?;
            let undo_log = runner.messages()[mark..].to_vec();

            expect_that!(
                undo_log[0],
                pat!(LogMessage::RollbackUnloadModule("test_package"))
            );
            expect_eq!(undo_log.len(), applied_steps + 1);

            expect_pred!(td.join(DST_DIR_PATH).exists());

            if dir_was_unlinked {
                expect_that!(
                    undo_log,
                    contains(pat!(LogMessage::CreateSymlink {
                        src: &td.join(SRC_DIR_PATH),
                        dst: &td.join(DST_DIR_PATH),
                    }))
                );
            }

            Ok(())
        }
    }
}