1#![forbid(unsafe_code)]
2use std::fs;
3use std::io::{self, Error, ErrorKind};
4use std::process::{Command, Output};
5
6const LINUX_DISTROS: [(&str, PackageManager); 8] = [
7 ("alpine", PackageManager::Apk),
8 ("ubuntu", PackageManager::Apt),
9 ("debian", PackageManager::Apt),
10 ("fedora", PackageManager::Dnf),
11 ("rhel", PackageManager::Dnf),
12 ("arch", PackageManager::Pacman),
13 ("gentoo", PackageManager::Portage),
14 ("opensuse", PackageManager::Zypper),
15];
16
/// A system package manager this module knows how to query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PackageManager {
    /// `apk`, used for Alpine.
    Apk,
    /// `apt`, used for Ubuntu and Debian.
    Apt,
    /// `dnf`, used for Fedora and RHEL.
    Dnf,
    /// `pacman`, used for Arch.
    Pacman,
    /// `portage`, used for Gentoo.
    Portage,
    /// `zypper`, used for openSUSE.
    Zypper,
}
26
27impl PackageManager {
28 #[must_use]
29 pub const fn name(&self) -> &'static str {
30 match self {
31 Self::Apk => "apk",
32 Self::Apt => "apt",
33 Self::Dnf => "dnf",
34 Self::Pacman => "pacman",
35 Self::Portage => "portage",
36 Self::Zypper => "zypper",
37 }
38 }
39
40 pub fn package_count(&self) -> io::Result<u64> {
47 self.package_count_with(run_count)
48 }
49
50 fn package_count_with(self, run: fn(&str) -> io::Result<u64>) -> io::Result<u64> {
51 #[allow(clippy::literal_string_with_formatting_args)]
52 match self {
53 Self::Apk => run("apk info | wc -l"),
54 Self::Apt => run("dpkg-query -f '${binary:Package}\\n' -W | wc -l"),
55 Self::Dnf | Self::Zypper => run("rpm -qa | wc -l"),
56 Self::Pacman => run("pacman -Q | wc -l"),
57 Self::Portage => run("qlist -I | wc -l"),
58 }
59 }
60}
61
62pub fn detect_package_manager() -> io::Result<PackageManager> {
70 let os_release = fs::read_to_string("/etc/os-release")?;
71 detect_from_os_release(&os_release)
72}
73
74fn detect_from_os_release(os_release: &str) -> io::Result<PackageManager> {
75 let id = read_key(os_release, "ID");
76 let id_like = read_key(os_release, "ID_LIKE");
77 if id.is_none() && id_like.is_none() {
78 return Err(Error::new(ErrorKind::InvalidData, "missing ID and ID_LIKE"));
79 }
80
81 if let Some(distro) = id
82 && let Some(manager) = lookup(distro)
83 {
84 return Ok(manager);
85 }
86
87 if let Some(distros) = id_like {
88 for distro in distros.split_ascii_whitespace() {
89 if let Some(manager) = lookup(distro) {
90 return Ok(manager);
91 }
92 }
93 }
94
95 Err(Error::new(ErrorKind::InvalidInput, "unknown pkg manager"))
96}
97
98fn lookup(id: &str) -> Option<PackageManager> {
99 LINUX_DISTROS
100 .iter()
101 .find(|(distro, _)| *distro == id)
102 .map(|(_, manager)| *manager)
103}
104
/// Returns the value assigned to `prefix` in os-release formatted text.
///
/// Leading whitespace on each line is ignored, and the first line whose key
/// equals `prefix` wins. os-release(5) allows values to be wrapped in either
/// double or single quotes, so surrounding whitespace and quote characters of
/// both kinds are stripped from the value.
fn read_key<'a>(os: &'a str, prefix: &str) -> Option<&'a str> {
    os.lines()
        .filter_map(|line| line.trim_start().split_once('='))
        .find(|(key, _)| *key == prefix)
        .map(|(_, val)| {
            val.trim()
                .trim_matches(|c: char| c == '"' || c == '\'')
                .trim()
        })
}
111
/// Runs `cmd` through `sh -c`, capturing its exit status, stdout and stderr.
///
/// # Errors
///
/// Returns an error only if `sh` could not be spawned or waited on. A
/// non-zero exit status is reported through the returned [`Output`].
fn run_cmd(cmd: &str) -> io::Result<Output> {
    let mut shell = Command::new("sh");
    shell.args(["-c", cmd]);
    shell.output()
}
115
116fn run_count(cmd: &str) -> io::Result<u64> {
117 run_count_with(cmd, run_cmd)
118}
119
120fn run_count_with(cmd: &str, run: fn(&str) -> io::Result<Output>) -> io::Result<u64> {
121 let output = run(cmd)?;
122 if !output.status.success() {
123 return Err(Error::other("command failed"));
124 }
125
126 let text = std::str::from_utf8(&output.stdout)
127 .map_err(|_| Error::new(ErrorKind::InvalidData, "non-utf8 output"))?;
128 parse_count(text)
129}
130
/// Parses a decimal count from command output, ignoring surrounding whitespace.
///
/// # Errors
///
/// Returns [`ErrorKind::InvalidData`] with message "empty output" when `text`
/// is blank, or "invalid count" when it is not an integer that fits in `u64`.
fn parse_count(text: &str) -> io::Result<u64> {
    match text.trim() {
        "" => Err(Error::new(ErrorKind::InvalidData, "empty output")),
        digits => digits
            .parse::<u64>()
            .map_err(|_| Error::new(ErrorKind::InvalidData, "invalid count")),
    }
}
141
#[cfg(test)]
mod tests {
    use super::*;
    use std::os::unix::process::ExitStatusExt;
    use std::process::ExitStatus;

    /// Builds a Unix `ExitStatus` for a normal exit with `code`.
    ///
    /// `ExitStatus::from_raw` takes the raw `waitpid` status, where the exit
    /// code lives in bits 8..16. A raw value of `1` would instead mean
    /// "terminated by signal 1".
    fn exited_with(code: i32) -> ExitStatus {
        ExitStatus::from_raw(code << 8)
    }

    #[test]
    fn supported_distro_count_matches_expected() {
        assert_eq!(LINUX_DISTROS.len(), 8);
    }

    #[test]
    fn detects_package_managers_from_id() {
        let cases = [
            ("alpine", "apk"),
            ("ubuntu", "apt"),
            ("debian", "apt"),
            ("fedora", "dnf"),
            ("rhel", "dnf"),
            ("arch", "pacman"),
            ("gentoo", "portage"),
            ("opensuse", "zypper"),
        ];

        for (id, expected) in cases {
            let sample = format!("NAME=Foo\nID={id}\n");
            let pm = detect_from_os_release(&sample).expect("should match");
            assert_eq!(pm.name(), expected, "ID={id}");
        }
    }

    #[test]
    fn detects_package_managers_from_id_like() {
        let cases = [
            ("almalinux", "rhel centos fedora", "dnf"),
            ("linuxmint", "ubuntu", "apt"),
            ("manjaro", "arch", "pacman"),
            ("opensuse-tumbleweed", "opensuse suse", "zypper"),
        ];

        for (id, id_like, expected) in cases {
            let sample = format!("NAME=Foo\nID={id}\nID_LIKE={id_like}\n");
            let pm = detect_from_os_release(&sample).expect("should match");
            assert_eq!(pm.name(), expected, "ID={id} ID_LIKE={id_like}");
        }
    }

    #[test]
    fn detects_double_quoted_values() {
        let sample = "NAME=\"Foo\"\nID=\"almalinux\"\nID_LIKE=\"rhel centos fedora\"\n";
        let pm = detect_from_os_release(sample).expect("should match");
        assert_eq!(pm, PackageManager::Dnf);
    }

    #[test]
    fn prefers_id_over_id_like() {
        // ID and ID_LIKE must map to *different* managers, otherwise this test
        // cannot tell which one was consulted first.
        let sample = "NAME=Foo\nID=ubuntu\nID_LIKE=arch\n";
        let pm = detect_from_os_release(sample).expect("should match");
        assert_eq!(pm.name(), "apt");
    }

    #[test]
    fn rejects_missing_id_and_id_like() {
        let sample = "NAME=Foo\n";
        let err = detect_from_os_release(sample).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }

    #[test]
    fn rejects_unknown_id_like() {
        let sample = "ID_LIKE=unknown";
        let err = detect_from_os_release(sample).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
    }

    #[test]
    fn rejects_unknown_id() {
        let sample = "ID=unknown\n";
        let err = detect_from_os_release(sample).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
    }

    /// Stand-in for `run_count` that maps each expected command to a distinct
    /// count, so the test can tell which command was issued.
    fn fake_run(cmd: &str) -> io::Result<u64> {
        match cmd {
            "apk info | wc -l" => Ok(10),
            "dpkg-query -f '${binary:Package}\\n' -W | wc -l" => Ok(20),
            "rpm -qa | wc -l" => Ok(30),
            "pacman -Q | wc -l" => Ok(40),
            "qlist -I | wc -l" => Ok(50),
            _ => Err(Error::new(ErrorKind::InvalidInput, "unknown cmd")),
        }
    }

    #[test]
    fn package_count_uses_expected_commands() {
        let cases = [
            (PackageManager::Apk, 10),
            (PackageManager::Apt, 20),
            (PackageManager::Dnf, 30),
            (PackageManager::Pacman, 40),
            (PackageManager::Portage, 50),
            (PackageManager::Zypper, 30),
        ];

        for (pm, expected) in cases {
            let count = pm.package_count_with(fake_run).expect("count ok");
            assert_eq!(count, expected, "{pm:?}");
        }
    }

    #[test]
    fn fake_run_rejects_unknown_command() {
        let err = fake_run("nope").unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
    }

    // The fake runners below must match the `fn(&str) -> io::Result<Output>`
    // signature expected by `run_count_with`, so the always-`Ok` wrapping is
    // intentional.
    #[allow(clippy::unnecessary_wraps)]
    fn fake_output_ok(_cmd: &str) -> io::Result<Output> {
        Ok(Output {
            status: exited_with(0),
            stdout: b"42\n".to_vec(),
            stderr: Vec::new(),
        })
    }

    #[allow(clippy::unnecessary_wraps)]
    fn fake_output_bad(_cmd: &str) -> io::Result<Output> {
        Ok(Output {
            status: exited_with(1),
            stdout: Vec::new(),
            stderr: Vec::new(),
        })
    }

    fn fake_output_err(_cmd: &str) -> io::Result<Output> {
        Err(Error::new(ErrorKind::NotFound, "missing cmd"))
    }

    #[allow(clippy::unnecessary_wraps)]
    fn fake_output_non_utf8(_cmd: &str) -> io::Result<Output> {
        Ok(Output {
            status: exited_with(0),
            stdout: vec![0xff, 0xfe, 0xfd],
            stderr: Vec::new(),
        })
    }

    #[test]
    fn run_count_with_parses_stdout() {
        let count = run_count_with("ignored", fake_output_ok).expect("count ok");
        assert_eq!(count, 42);
    }

    #[test]
    fn run_count_with_fails_on_status() {
        let err = run_count_with("ignored", fake_output_bad).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::Other);
    }

    #[test]
    fn run_count_with_rejects_non_utf8() {
        let err = run_count_with("ignored", fake_output_non_utf8).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }

    #[test]
    fn run_count_with_propagates_runner_error() {
        let err = run_count_with("ignored", fake_output_err).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::NotFound);
    }

    #[test]
    fn run_count_reports_missing_command_failure() {
        // `sh` exits non-zero when the command is not found.
        let err = run_count("cmd-that-should-not-exist").unwrap_err();
        assert_eq!(err.kind(), ErrorKind::Other);
    }

    #[test]
    fn parse_count_rejects_empty() {
        let err = parse_count(" ").unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }

    #[test]
    fn parse_count_rejects_invalid() {
        let err = parse_count("nope").unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }

    #[test]
    fn parse_count_rejects_negative() {
        let err = parse_count("-1").unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }

    #[test]
    fn parse_count_accepts_valid() {
        let count = parse_count(" 123 ").expect("count ok");
        assert_eq!(count, 123);
    }
}