1use std::fs;
2use std::io::{self, BufRead, Read, Write};
3use std::path::Path;
4
5use anyhow::Result;
6use log::debug;
7
8use crate::util::fs_util::path_attack_check;
9use crate::{IOErr, IOErrType};
10
/// Recursively deletes `dir` and everything beneath it, reporting each removal on `stdout`.
///
/// Entries are classified with [`fs::DirEntry::file_type`], which does not follow
/// symbolic links. A symlink found inside `dir` (even one pointing at a directory)
/// is therefore removed as a link, and its target is left untouched. Only real
/// sub-directories are descended into.
///
/// For every non-directory entry this writes `Removed file '<path>'`. For every
/// directory, once it has been emptied and removed, it writes
/// `Removed directory '<path>'`.
///
/// NOTE(review): `dir` itself is opened with `fs::read_dir`, which does follow a
/// symlink, so callers should pass a real directory.
///
/// # Errors
///
/// Returns the first I/O error from listing, removing, or writing to `stdout`.
/// Entries removed before the error stay removed.
fn remove_dir_recursive<O>(dir: &Path, stdout: &mut O) -> io::Result<()>
where
    O: Write,
{
    for entry in fs::read_dir(dir)? {
        let entry = entry?;
        let entry_path = entry.path();

        // `Path::is_dir` would follow a symlink and delete the *target's* contents,
        // possibly outside the repository. `DirEntry::file_type` inspects the entry itself.
        if entry.file_type()?.is_dir() {
            remove_dir_recursive(&entry_path, stdout)?;
        } else {
            fs::remove_file(&entry_path)?;
            writeln!(stdout, "Removed file '{}'", entry_path.display())?;
        }
    }

    fs::remove_dir(dir)?;
    writeln!(stdout, "Removed directory '{}'", dir.display())?;
    Ok(())
}
31
32pub fn remove_io<I, O, E>(
33 root: &Path,
34 dist: &str,
35 recursive: bool,
36 force: bool,
37 stdin: &mut I,
38 stdout: &mut O,
39 stderr: &mut E,
40) -> Result<()>
41where
42 I: Read + BufRead,
43 O: Write,
44 E: Write,
45{
46 let mut dist_path = root.join(dist);
47 path_attack_check(root, &dist_path)?;
48
49 if !dist_path.exists() || !dist_path.is_dir() {
50 debug!("Try to delete dir {:?}, which not exist", dist_path);
51 dist_path = root.join(format!("{}.gpg", dist));
52 if !dist_path.exists() || !dist_path.is_file() {
53 if force {
54 writeln!(stdout, "Noting to remove")?;
55 return Ok(());
56 }
57 debug!("Try to delete file {:?}, which not exist", dist_path);
58 writeln!(stderr, "Cannot remove '{}': No such file or directory", dist)?;
59 return Err(IOErr::new(IOErrType::PathNotExist, &dist_path).into());
60 }
61 }
62
63 if !force {
64 let confirm_msg = format!(
65 "Are you sure you would like to delete '{}' in repo '{}'? [y/N]: ",
66 dist,
67 root.display()
68 );
69 write!(stdout, "{}", confirm_msg)?;
70 stdout.flush()?;
71 let mut input = String::new();
72 stdin.read_line(&mut input)?;
73 if !input.trim().to_lowercase().starts_with('y') {
74 return Ok(());
75 }
76 }
77
78 if dist_path.is_file() {
79 fs::remove_file(&dist_path)?;
80 writeln!(stderr, "Removed '{}'", dist)?;
81 } else if dist_path.is_dir() {
82 if recursive {
83 remove_dir_recursive(&dist_path, stdout)?;
84 } else {
85 let err_msg = format!("Cannot remove '{}': Is a directory.", dist);
86 writeln!(stderr, "{}", err_msg)?;
87 return Err(IOErr::new(IOErrType::ExpectFile, &dist_path).into());
88 }
89 } else {
90 let err_msg = format!("Cannot remove '{}': Not a file or directory.", dist);
91 writeln!(stderr, "{}", err_msg)?;
92 return Err(IOErr::new(IOErrType::InvalidFileType, &dist_path).into());
93 }
94
95 Ok(())
96}
97
#[cfg(test)]
mod test {
    use std::io::BufReader;
    use std::thread::sleep;
    use std::time::Duration;
    use std::{io, thread};

    use os_pipe::pipe;

    use super::*;
    use crate::util::defer::cleanup;
    use crate::util::fs_util::set_readonly;
    use crate::util::test_util::{create_dir_structure, gen_unique_temp_dir};

    /// Spawns a thread that writes `input_str` to `stdin_writer` after `delay`,
    /// simulating a user typing an answer to a prompt.
    fn enter_input_with_delay<T>(
        input_str: &str,
        delay: Duration,
        mut stdin_writer: T,
    ) -> thread::JoinHandle<()>
    where
        T: Write + Send + 'static,
    {
        let input = input_str.to_string();
        thread::spawn(move || {
            sleep(delay);
            stdin_writer.write_all(input.as_bytes()).unwrap();
        })
    }

    /// Calls [`remove_io`] with a fresh pipe as stdin, answering any prompt with `answer`.
    ///
    /// The input thread is joined while the pipe's read end is still alive, so the
    /// delayed write never hits a closed pipe, even when `remove_io` does not read.
    fn remove_with_answer(
        root: &Path,
        dist: &str,
        recursive: bool,
        force: bool,
        answer: &str,
    ) -> Result<()> {
        let mut stdout = io::stdout().lock();
        let mut stderr = io::stderr().lock();
        let (stdin_r, stdin_w) = pipe().unwrap();
        let mut stdin = BufReader::new(stdin_r);

        let input_thread = enter_input_with_delay(answer, Duration::from_millis(100), stdin_w);
        let result = remove_io(root, dist, recursive, force, &mut stdin, &mut stdout, &mut stderr);
        input_thread.join().unwrap();
        result
    }

    #[test]
    fn remove_io_test() {
        let (_tmp_dir, root) = gen_unique_temp_dir();
        let structure: &[(Option<&str>, &[&str])] = &[
            (Some("dir1"), &["file1.gpg", "file2.gpg"]),
            (Some("dir2"), &[]),
            (None, &["file3.gpg"]),
        ];
        create_dir_structure(&root, structure);
        set_readonly(root.join("file3.gpg"), true).unwrap();
        set_readonly(root.join("dir1").join("file1.gpg"), true).unwrap();

        cleanup!(
            {
                // Declining the prompt keeps the file.
                let dist = "file3";
                remove_with_answer(&root, dist, false, false, "n\n").unwrap();
                assert!(root.join("file3.gpg").exists());

                // `force` skips the prompt and removes the backing `.gpg` file. Checking
                // `root/file3` here would pass vacuously: that path never existed.
                remove_with_answer(&root, dist, false, true, "y\n").unwrap();
                assert!(!root.join("file3.gpg").exists());

                // A directory is refused without `recursive`, and is left in place ...
                let dist = "dir2";
                let result = remove_with_answer(&root, dist, false, false, "y\n");
                assert!(
                    result.is_err(),
                    "Expect to fail to remove a directory without recursive option."
                );
                assert!(root.join(dist).exists());

                // ... and removed once `recursive` is set.
                remove_with_answer(&root, dist, true, false, "y\n").unwrap();
                assert!(!root.join(dist).exists());

                // Recursive removal also deletes the read-only file inside `dir1`.
                let dist = "dir1";
                remove_with_answer(&root, dist, true, false, "y\n").unwrap();
                assert!(!root.join(dist).exists());

                // A missing target is an error without `force` ...
                let dist = "non-exist-file";
                let result = remove_with_answer(&root, dist, false, false, "y\n");
                assert!(
                    result.is_err(),
                    "Expect to fail to remove a non-exist file without force option."
                );

                // ... and a successful no-op with it.
                remove_with_answer(&root, dist, false, true, "y\n").unwrap();
            },
            {}
        )
    }
}