// pkgs/runner/rollback.rs

1use super::{Runner, RunnerError};
2use crate::logger::{LogMessage, LoggerOutput};
3
4impl<O: LoggerOutput> Runner<O> {
5    pub fn rollback(&mut self) -> Result<(), RunnerError> {
6        let Some(actions) = self.last_action() else {
7            return Err(RunnerError::NoActionToRollback);
8        };
9
10        let (head, actions) = actions.split_first().unwrap();
11        match head {
12            LogMessage::LoadModule(module) => self.logger.rollback_load_module(module),
13            LogMessage::UnloadModule(module) => self.logger.rollback_unload_module(module),
14            _ => unreachable!(),
15        }
16
17        for action in actions.iter().rev() {
18            match action {
19                LogMessage::LoadModule(_)
20                | LogMessage::UnloadModule(_)
21                | LogMessage::RollbackLoadModule(_)
22                | LogMessage::RollbackUnloadModule(_) => unreachable!(),
23
24                LogMessage::CreateDir(path) => self.remove_dir(path)?,
25                LogMessage::CreateSymlink { src, dst } => self.remove_symlink(src, dst)?,
26
27                LogMessage::RemoveDir(path) => self.create_dir(path)?,
28                LogMessage::RemoveSymlink { src, dst } => self.create_symlink(src, dst)?,
29            }
30        }
31
32        Ok(())
33    }
34
35    fn last_action(&self) -> Option<Vec<LogMessage>> {
36        let msgs = self.messages();
37        for i in (0..msgs.len()).rev() {
38            match &msgs[i] {
39                LogMessage::LoadModule(_) | LogMessage::UnloadModule(_) => {
40                    return Some(msgs[i..].to_vec());
41                }
42                LogMessage::RollbackLoadModule(_) | LogMessage::RollbackUnloadModule(_) => {
43                    return None;
44                }
45                _ => {}
46            }
47        }
48        None
49    }
50}
51
#[cfg(test)]
mod tests {
    use std::fs;

    use crate::test_utils::prelude::*;

    /// Rolling back a runner that has not logged any module action fails with
    /// `RunnerError::NoActionToRollback`.
    #[gtest]
    fn nothing_to_rollback() -> Result<()> {
        let (_td, _pkg, mut runner) = common_local_pkg()?;
        let err = runner.rollback().unwrap_err();
        expect_that!(err, pat!(RunnerError::NoActionToRollback));
        Ok(())
    }

    /// A module action can be rolled back only once: the second `rollback` call fails
    /// with `RunnerError::NoActionToRollback`.
    #[gtest]
    fn rollback_twice() -> Result<()> {
        let (_td, pkg, mut runner) = common_local_pkg()?;
        runner.load_module(&pkg, None)?;
        runner.rollback()?;

        let err = runner.rollback().unwrap_err();
        expect_that!(err, pat!(RunnerError::NoActionToRollback));

        Ok(())
    }

    /// Rolling back a `load_module` call.
    mod rollback_load_module {
        use super::*;

        /// After a successful load, rollback logs `RollbackLoadModule`, logs one undo
        /// message per action of the load, and removes both destination links.
        #[gtest]
        fn after_success() -> Result<()> {
            let (td, pkg, mut runner) = common_local_pkg()?;

            runner.load_module(&pkg, None)?;
            // Actions logged by the load, skipping the leading module message.
            let msgs = runner.messages()[1..].to_vec();
            // Everything logged from this index on is produced by `rollback`.
            let rollback_begin = runner.messages().len();

            runner.rollback()?;
            let rollback_msgs = runner.messages()[rollback_begin..].to_vec();

            // The rollback marker comes first, followed by one message per undone action.
            expect_that!(
                rollback_msgs[0],
                pat!(LogMessage::RollbackLoadModule("test_package"))
            );
            expect_eq!(rollback_msgs.len(), msgs.len() + 1);

            expect_pred!(!td.join(DST_DIR_PATH).exists());
            expect_pred!(!td.join(DST_FILE_PATH).exists());

            // Each link removal is checked together with a directory removal;
            // presumably these are the parent directories the load created for the
            // destinations — TODO confirm against DST_FILE_PATH / DST_DIR_PATH.
            // Symlink sources are compared in canonical form.
            expect_that!(
                rollback_msgs,
                superset_of([
                    &LogMessage::RemoveSymlink {
                        src: td.join(SRC_FILE_PATH).canonicalize()?,
                        dst: td.join(DST_FILE_PATH)
                    },
                    &LogMessage::RemoveDir(td.join("./test_pkg")),
                ])
            );
            expect_that!(
                rollback_msgs,
                superset_of([
                    &LogMessage::RemoveSymlink {
                        src: td.join(SRC_DIR_PATH).canonicalize()?,
                        dst: td.join(DST_DIR_PATH)
                    },
                    &LogMessage::RemoveDir(td.join("./test_a/test_b")),
                ])
            );

            Ok(())
        }

        /// Rolling back a load that failed part-way (its file source was deleted first)
        /// undoes whatever the load did before failing and leaves no destination links.
        #[gtest]
        fn after_failure() -> Result<()> {
            let (td, pkg, mut runner) = common_local_pkg()?;
            // Delete the file source so that `load_module` fails.
            fs::remove_file(td.join(SRC_FILE_PATH))?;

            let _ = runner.load_module(&pkg, None).unwrap_err();
            let msgs = runner.messages()[1..].to_vec();
            let rollback_begin = runner.messages().len();
            // Whether the directory link was created before the load failed.
            // NOTE(review): presumably depends on the order in which package entries are
            // processed, hence the conditional check below — confirm.
            let load_src_dir = td.join(DST_DIR_PATH).exists();

            runner.rollback()?;
            let rollback_msgs = runner.messages()[rollback_begin..].to_vec();

            expect_that!(
                rollback_msgs[0],
                pat!(LogMessage::RollbackLoadModule("test_package"))
            );
            expect_eq!(rollback_msgs.len(), msgs.len() + 1);

            expect_pred!(!td.join(DST_DIR_PATH).exists());
            expect_pred!(!td.join(DST_FILE_PATH).exists());

            if load_src_dir {
                expect_that!(
                    rollback_msgs,
                    superset_of([
                        &LogMessage::RemoveSymlink {
                            src: td.join(SRC_DIR_PATH).canonicalize()?,
                            dst: td.join(DST_DIR_PATH)
                        },
                        &LogMessage::RemoveDir(td.join("./test_a/test_b")),
                    ])
                );
            }

            Ok(())
        }

        /// A second load adds one extra mapping and runs on a fresh runner, given the
        /// first load's trace. Rolling it back must remove only the new link; the links
        /// from the first load stay in place.
        #[gtest]
        fn only_rollback_last_loading() -> Result<()> {
            let (td, mut pkg, mut runner) = common_local_pkg()?;
            let trace = runner.load_module(&pkg, None)?;

            // Create a new source file and map it to a new destination.
            let new_src_file = "test_package/new_src_file";
            let td = td.file(new_src_file, "")?;
            pkg.insert_map("new_src_file", td.join("new_dst_file").to_string_lossy());

            let mut runner = common_runner(td.path());
            runner.load_module(&pkg, Some(&trace))?;

            let msgs = runner.messages()[1..].to_vec();
            let rollback_begin = runner.messages().len();

            runner.rollback()?;
            let rollback_msgs = runner.messages()[rollback_begin..].to_vec();

            expect_that!(
                rollback_msgs[0],
                pat!(LogMessage::RollbackLoadModule("test_package"))
            );
            expect_eq!(rollback_msgs.len(), msgs.len() + 1);

            // Links from the first load survive; only the newly mapped one is gone.
            expect_pred!(td.join(DST_DIR_PATH).exists());
            expect_pred!(td.join(DST_FILE_PATH).exists());
            expect_pred!(!td.join("new_dst_file").exists());

            expect_that!(
                rollback_msgs,
                superset_of([&LogMessage::RemoveSymlink {
                    src: td.join(new_src_file).canonicalize()?,
                    dst: td.join("new_dst_file")
                },])
            );

            Ok(())
        }
    }

    /// Rolling back an `unload_module` call.
    mod rollback_unload_module {
        use super::*;

        /// After a successful unload, rollback logs `RollbackUnloadModule` and recreates
        /// both destination links.
        #[gtest]
        fn after_success() -> Result<()> {
            let (td, pkg, mut runner) = common_local_pkg()?;
            let trace = runner.load_module(&pkg, None)?;

            // Unload on a fresh runner. Presumably it starts with an empty log, so
            // `messages()[0]` below is the unload's own module message — confirm.
            let mut runner = common_runner(td.path());
            runner.unload_module("test_package", &trace)?;

            let msgs = runner.messages()[1..].to_vec();
            let rollback_begin = runner.messages().len();

            runner.rollback()?;
            let rollback_msgs = runner.messages()[rollback_begin..].to_vec();

            expect_that!(
                rollback_msgs[0],
                pat!(LogMessage::RollbackUnloadModule("test_package"))
            );
            expect_eq!(rollback_msgs.len(), msgs.len() + 1);

            expect_pred!(td.join(DST_DIR_PATH).exists());
            expect_pred!(td.join(DST_FILE_PATH).exists());

            // NOTE(review): unlike the load tests, the expected `src` paths here are not
            // canonicalized. If the temp dir path is not canonical (e.g. reached through
            // a symlink), these may not match — confirm.
            expect_that!(
                rollback_msgs,
                contains(pat!(LogMessage::CreateSymlink {
                    src: &td.join(SRC_DIR_PATH),
                    dst: &td.join(DST_DIR_PATH),
                }))
            );
            expect_that!(
                rollback_msgs,
                contains(pat!(LogMessage::CreateSymlink {
                    src: &td.join(SRC_FILE_PATH),
                    dst: &td.join(DST_FILE_PATH),
                }))
            );

            Ok(())
        }

        /// Rolling back an unload that failed part-way (the file link was deleted
        /// beforehand) restores what the unload removed before failing.
        #[gtest]
        fn after_failure() -> Result<()> {
            let (td, pkg, mut runner) = common_local_pkg()?;
            let trace = runner.load_module(&pkg, None)?;
            // Remove the file link outside the runner so that `unload_module` fails.
            fs::remove_file(td.join(DST_FILE_PATH))?;

            let mut runner = common_runner(td.path());
            let _ = runner.unload_module("test_package", &trace).unwrap_err();

            let msgs = runner.messages()[1..].to_vec();
            let rollback_begin = runner.messages().len();
            // Whether the unload removed the directory link before failing.
            // NOTE(review): compared with a non-canonicalized `src`. If that never
            // matches, the `if unload_src_dir` check below is silently skipped — confirm.
            let unload_src_dir = runner.messages().contains(&LogMessage::RemoveSymlink {
                src: td.join(SRC_DIR_PATH),
                dst: td.join(DST_DIR_PATH),
            });

            runner.rollback()?;
            let rollback_msgs = runner.messages()[rollback_begin..].to_vec();

            expect_that!(
                rollback_msgs[0],
                pat!(LogMessage::RollbackUnloadModule("test_package"))
            );
            expect_eq!(rollback_msgs.len(), msgs.len() + 1);

            expect_pred!(td.join(DST_DIR_PATH).exists());

            if unload_src_dir {
                expect_that!(
                    rollback_msgs,
                    contains(pat!(LogMessage::CreateSymlink {
                        src: &td.join(SRC_DIR_PATH),
                        dst: &td.join(DST_DIR_PATH),
                    }))
                );
            }

            Ok(())
        }
    }
}