pkgs/runner/
rollback.rs

1use super::{Runner, RunnerError};
2use crate::logger::{LogMessage, LoggerOutput};
3
4impl<O: LoggerOutput> Runner<O> {
5    pub fn rollback(&mut self) -> Result<(), RunnerError> {
6        let Some(actions) = self.last_action() else {
7            return Err(RunnerError::NoActionToRollback);
8        };
9
10        let (head, actions) = actions.split_first().unwrap();
11        match head {
12            LogMessage::LoadModule(module) => self.logger.rollback_load_module(module),
13            LogMessage::UnloadModule(module) => self.logger.rollback_unload_module(module),
14            _ => unreachable!(),
15        }
16
17        for action in actions.iter().rev() {
18            match action {
19                LogMessage::LoadModule(_)
20                | LogMessage::UnloadModule(_)
21                | LogMessage::RollbackLoadModule(_)
22                | LogMessage::RollbackUnloadModule(_) => unreachable!(),
23
24                LogMessage::CreateDir(path) => self.remove_dir(path)?,
25                LogMessage::CreateSymlink { src, dst } => self.remove_symlink(src, dst)?,
26
27                LogMessage::RemoveDir(path) => self.create_dir(path)?,
28                LogMessage::RemoveSymlink { src, dst } => self.create_symlink(src, dst)?,
29            }
30        }
31
32        Ok(())
33    }
34
35    fn last_action(&self) -> Option<Vec<LogMessage>> {
36        let msgs = self.messages();
37        for i in (0..msgs.len()).rev() {
38            match &msgs[i] {
39                LogMessage::LoadModule(_) | LogMessage::UnloadModule(_) => {
40                    return Some(msgs[i..].to_vec());
41                }
42                LogMessage::RollbackLoadModule(_) | LogMessage::RollbackUnloadModule(_) => {
43                    return None;
44                }
45                _ => {}
46            }
47        }
48        None
49    }
50}
51
#[cfg(test)]
mod tests {
    use std::fs;

    use crate::test_utils::prelude::*;

    // A fresh runner has no recorded action, so rollback must refuse.
    #[gtest]
    fn nothing_to_rollback() -> Result<()> {
        let (_tmp, _pkg, mut runner) = common_local_pkg()?;

        let error = runner.rollback().unwrap_err();

        expect_that!(error, pat!(RunnerError::NoActionToRollback));
        Ok(())
    }

    // Once an action has been rolled back, a second rollback finds nothing.
    #[gtest]
    fn rollback_twice() -> Result<()> {
        let (_tmp, pkg, mut runner) = common_local_pkg()?;
        runner.load_module(&pkg, None)?;
        runner.rollback()?;

        let error = runner.rollback().unwrap_err();

        expect_that!(error, pat!(RunnerError::NoActionToRollback));
        Ok(())
    }

    mod rollback_load_module {
        use super::*;

        // Rolling back a successful load removes everything it created and
        // logs the removals after the rollback marker.
        #[gtest]
        fn after_success() -> Result<()> {
            let (tmp, pkg, mut runner) = common_local_pkg()?;

            runner.load_module(&pkg, None)?;
            let mark = runner.messages().len();

            runner.rollback()?;
            let rollback_msgs = runner.messages()[mark..].to_vec();

            expect_pred!(!tmp.join(DST_DIR_PATH).exists());
            expect_pred!(!tmp.join(DST_FILE_PATH).exists());

            expect_eq!(
                rollback_msgs,
                [
                    LogMessage::RollbackLoadModule("test_package".into()),
                    LogMessage::RemoveSymlink {
                        src: tmp.join(SRC_DIR_PATH).canonicalize()?,
                        dst: tmp.join(DST_DIR_PATH)
                    },
                    LogMessage::RemoveDir(tmp.join("./test_a/test_b")),
                    LogMessage::RemoveSymlink {
                        src: tmp.join(SRC_FILE_PATH).canonicalize()?,
                        dst: tmp.join(DST_FILE_PATH)
                    },
                    LogMessage::RemoveDir(tmp.join("./test_pkg")),
                ]
            );

            Ok(())
        }

        // The source directory is deleted up front so the load fails; the
        // rollback then only reverts the steps that were actually logged.
        #[gtest]
        fn after_failure() -> Result<()> {
            let (tmp, pkg, mut runner) = common_local_pkg()?;
            fs::remove_dir(tmp.join(SRC_DIR_PATH))?;

            let _ = runner.load_module(&pkg, None).unwrap_err();
            let mark = runner.messages().len();

            runner.rollback()?;
            let rollback_msgs = runner.messages()[mark..].to_vec();

            expect_pred!(!tmp.join(DST_DIR_PATH).exists());
            expect_pred!(!tmp.join(DST_FILE_PATH).exists());

            expect_eq!(
                rollback_msgs,
                [
                    LogMessage::RollbackLoadModule("test_package".into()),
                    LogMessage::RemoveSymlink {
                        src: tmp.join(SRC_FILE_PATH).canonicalize()?,
                        dst: tmp.join(DST_FILE_PATH)
                    },
                    LogMessage::RemoveDir(tmp.join("./test_pkg")),
                ]
            );

            Ok(())
        }

        // A second load (in a new runner, given the first load's trace) adds
        // one extra mapping; rolling it back must leave the first load's
        // results in place.
        #[gtest]
        fn only_rollback_last_loading() -> Result<()> {
            let (tmp, mut pkg, mut first_runner) = common_local_pkg()?;
            let trace = first_runner.load_module(&pkg, None)?;

            let new_src_file = "test_package/new_src_file";
            let tmp = tmp.file(new_src_file, "")?;
            pkg.insert_map("new_src_file", tmp.join("new_dst_file").to_string_lossy());

            let mut runner = common_runner(tmp.path());
            runner.load_module(&pkg, Some(&trace))?;

            let mark = runner.messages().len();

            runner.rollback()?;
            let rollback_msgs = runner.messages()[mark..].to_vec();

            expect_pred!(tmp.join(DST_DIR_PATH).exists());
            expect_pred!(tmp.join(DST_FILE_PATH).exists());
            expect_pred!(!tmp.join("new_dst_file").exists());

            expect_eq!(
                rollback_msgs,
                [
                    LogMessage::RollbackLoadModule("test_package".into()),
                    LogMessage::RemoveSymlink {
                        src: tmp.join(new_src_file).canonicalize()?,
                        dst: tmp.join("new_dst_file")
                    }
                ]
            );

            Ok(())
        }
    }

    mod rollback_unload_module {
        use super::*;

        // Rolling back a successful unload re-creates the destinations and
        // logs one message per unload step plus the rollback marker.
        #[gtest]
        fn after_success() -> Result<()> {
            let (tmp, pkg, mut loader) = common_local_pkg()?;
            let trace = loader.load_module(&pkg, None)?;

            let mut runner = common_runner(tmp.path());
            runner.unload_module("test_package", &trace)?;

            // Everything after the first message of this runner's log.
            let unload_steps = runner.messages()[1..].len();
            let mark = runner.messages().len();

            runner.rollback()?;
            let rollback_msgs = runner.messages()[mark..].to_vec();

            expect_that!(
                rollback_msgs[0],
                pat!(LogMessage::RollbackUnloadModule("test_package"))
            );
            expect_eq!(rollback_msgs.len(), unload_steps + 1);

            expect_pred!(tmp.join(DST_DIR_PATH).exists());
            expect_pred!(tmp.join(DST_FILE_PATH).exists());

            expect_that!(
                rollback_msgs,
                contains(pat!(LogMessage::CreateSymlink {
                    src: &tmp.join(SRC_DIR_PATH),
                    dst: &tmp.join(DST_DIR_PATH),
                }))
            );
            expect_that!(
                rollback_msgs,
                contains(pat!(LogMessage::CreateSymlink {
                    src: &tmp.join(SRC_FILE_PATH),
                    dst: &tmp.join(DST_FILE_PATH),
                }))
            );

            Ok(())
        }

        // The destination file is deleted up front so the unload fails; the
        // rollback must still restore whatever the unload managed to remove.
        #[gtest]
        fn after_failure() -> Result<()> {
            let (tmp, pkg, mut loader) = common_local_pkg()?;
            let trace = loader.load_module(&pkg, None)?;
            fs::remove_file(tmp.join(DST_FILE_PATH))?;

            let mut runner = common_runner(tmp.path());
            let _ = runner.unload_module("test_package", &trace).unwrap_err();

            // Everything after the first message of this runner's log.
            let unload_steps = runner.messages()[1..].len();
            let mark = runner.messages().len();
            // Whether the failed unload got as far as removing the directory
            // symlink; only then can its re-creation be expected.
            let unload_src_dir = runner.messages().contains(&LogMessage::RemoveSymlink {
                src: tmp.join(SRC_DIR_PATH),
                dst: tmp.join(DST_DIR_PATH),
            });

            runner.rollback()?;
            let rollback_msgs = runner.messages()[mark..].to_vec();

            expect_that!(
                rollback_msgs[0],
                pat!(LogMessage::RollbackUnloadModule("test_package"))
            );
            expect_eq!(rollback_msgs.len(), unload_steps + 1);

            expect_pred!(tmp.join(DST_DIR_PATH).exists());

            if unload_src_dir {
                expect_that!(
                    rollback_msgs,
                    contains(pat!(LogMessage::CreateSymlink {
                        src: &tmp.join(SRC_DIR_PATH),
                        dst: &tmp.join(DST_DIR_PATH),
                    }))
                );
            }

            Ok(())
        }
    }
}