1use core::str;
4use std::process::Command;
5
6use anyhow::{bail, Context, Result};
7use log::error;
8use tokio_util::sync::CancellationToken;
9use tracing::info;
10use zbus::{proxy, zvariant, Connection};
11
12#[derive(Debug, Clone)]
13pub struct UnitWithStatus {
14 pub name: String, pub scope: UnitScope, pub description: String, pub file_path: Option<Result<String, String>>, pub load_state: String, pub activation_state: String,
24 pub sub_state: String,
26
27 pub enablement_state: Option<String>,
30 }
37
/// Which systemd manager a unit belongs to.
///
/// `Global` units are reached over the system bus and `User` units over the
/// session bus (see `get_connection`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum UnitScope {
  /// The system-wide manager.
  Global,
  /// The per-user manager.
  User,
}
43
44#[derive(Debug, Clone, PartialEq, Eq, Hash)]
46pub struct UnitId {
47 pub name: String,
48 pub scope: UnitScope,
49}
50
51impl UnitWithStatus {
52 pub fn is_active(&self) -> bool {
53 self.activation_state == "active"
54 }
55
56 pub fn is_failed(&self) -> bool {
57 self.activation_state == "failed"
58 }
59
60 pub fn is_not_found(&self) -> bool {
61 self.load_state == "not-found"
62 }
63
64 pub fn is_enabled(&self) -> bool {
65 self.load_state == "loaded" && self.activation_state == "active"
66 }
67
68 pub fn short_name(&self) -> &str {
69 if self.name.ends_with(".service") {
70 &self.name[..self.name.len() - 8]
71 } else {
72 &self.name
73 }
74 }
75
76 pub fn id(&self) -> UnitId {
78 UnitId { name: self.name.clone(), scope: self.scope }
79 }
80
81 pub fn update(&mut self, other: UnitWithStatus) {
83 self.description = other.description;
84 self.load_state = other.load_state;
85 self.activation_state = other.activation_state;
86 self.sub_state = other.sub_state;
87 }
88}
89
90type RawUnit =
91 (String, String, String, String, String, String, zvariant::OwnedObjectPath, u32, String, zvariant::OwnedObjectPath);
92
93fn to_unit_status(raw_unit: RawUnit, scope: UnitScope) -> UnitWithStatus {
94 let (name, description, load_state, active_state, sub_state, _followed, _path, _job_id, _job_type, _job_path) =
95 raw_unit;
96
97 UnitWithStatus {
98 name,
99 scope,
100 description,
101 file_path: None,
102 enablement_state: None,
103 load_state,
104 activation_state: active_state,
105 sub_state,
106 }
107}
108
/// Which set(s) of units [`get_all_services`] should list.
#[derive(Clone, Copy, Debug, Default)]
pub enum Scope {
  /// Only units from the system manager.
  Global,
  /// Only units from the user manager.
  User,
  /// Both system and user units; this is the default.
  #[default]
  All,
}
117
118pub async fn get_all_services(scope: Scope, services: &[String]) -> Result<Vec<UnitWithStatus>> {
120 let start = std::time::Instant::now();
121
122 let mut units = vec![];
123
124 let is_root = nix::unistd::geteuid().is_root();
125
126 match scope {
127 Scope::Global => {
128 let system_units = get_services(UnitScope::Global, services).await?;
129 units.extend(system_units);
130 },
131 Scope::User => {
132 let user_units = get_services(UnitScope::User, services).await?;
133 units.extend(user_units);
134 },
135 Scope::All => {
136 let (system_units, user_units) =
137 tokio::join!(get_services(UnitScope::Global, services), get_services(UnitScope::User, services));
138 units.extend(system_units?);
139
140 if let Ok(user_units) = user_units {
142 units.extend(user_units);
143 } else if is_root {
144 error!("Failed to get user units, ignoring because we're running as root")
145 } else {
146 user_units?;
147 }
148 },
149 }
150
151 units.sort_by(|a, b| a.name.to_lowercase().cmp(&b.name.to_lowercase()));
153
154 info!("Loaded systemd services in {:?}", start.elapsed());
155
156 Ok(units)
157}
158
159async fn get_services(scope: UnitScope, services: &[String]) -> Result<Vec<UnitWithStatus>, anyhow::Error> {
160 let connection = get_connection(scope).await?;
161 let manager_proxy = ManagerProxy::new(&connection).await?;
162 let units = manager_proxy.list_units_by_patterns(vec![], services.to_vec()).await?;
163 let units: Vec<_> = units.into_iter().map(|u| to_unit_status(u, scope)).collect();
164 Ok(units)
165}
166
167pub fn get_unit_file_location(service: &UnitId) -> Result<String> {
168 let mut args = vec!["--quiet", "show", "-P", "FragmentPath"];
170 args.push(&service.name);
171
172 if service.scope == UnitScope::User {
173 args.insert(0, "--user");
174 }
175
176 let output = Command::new("systemctl").args(&args).output()?;
177
178 if output.status.success() {
179 let path = str::from_utf8(&output.stdout)?.trim();
180 if path.is_empty() {
181 bail!("No unit file found for {}", service.name);
182 }
183 Ok(path.trim().to_string())
184 } else {
185 let stderr = String::from_utf8(output.stderr)?;
186 bail!(stderr);
187 }
188}
189
190pub async fn start_service(service: UnitId, cancel_token: CancellationToken) -> Result<()> {
191 async fn start_service(service: UnitId) -> Result<()> {
192 let connection = get_connection(service.scope).await?;
193 let manager_proxy = ManagerProxy::new(&connection).await?;
194 manager_proxy.start_unit(service.name.clone(), "replace".into()).await?;
195 Ok(())
196 }
197
198 tokio::select! {
200 _ = cancel_token.cancelled() => {
201 anyhow::bail!("cancelled");
202 }
203 result = start_service(service) => {
204 result
205 }
206 }
207}
208
209pub async fn stop_service(service: UnitId, cancel_token: CancellationToken) -> Result<()> {
210 async fn stop_service(service: UnitId) -> Result<()> {
211 let connection = get_connection(service.scope).await?;
212 let manager_proxy = ManagerProxy::new(&connection).await?;
213 manager_proxy.stop_unit(service.name, "replace".into()).await?;
214 Ok(())
215 }
216
217 tokio::select! {
219 _ = cancel_token.cancelled() => {
220 anyhow::bail!("cancelled");
221 }
222 result = stop_service(service) => {
223 result
224 }
225 }
226}
227
228pub async fn reload(scope: UnitScope, cancel_token: CancellationToken) -> Result<()> {
229 async fn reload_(scope: UnitScope) -> Result<()> {
230 let connection = get_connection(scope).await?;
231 let manager_proxy: ManagerProxy<'_> = ManagerProxy::new(&connection).await?;
232 let error_message = match scope {
233 UnitScope::Global => "Failed to reload units, probably because superuser permissions are needed. Try running `sudo systemctl daemon-reload`",
234 UnitScope::User => "Failed to reload units. Try running `systemctl --user daemon-reload`",
235 };
236 manager_proxy.reload().await.context(error_message)?;
237 Ok(())
238 }
239
240 tokio::select! {
242 _ = cancel_token.cancelled() => {
243 anyhow::bail!("cancelled");
244 }
245 result = reload_(scope) => {
246 result
247 }
248 }
249}
250
251async fn get_connection(scope: UnitScope) -> Result<Connection, anyhow::Error> {
252 match scope {
253 UnitScope::Global => Ok(Connection::system().await?),
254 UnitScope::User => Ok(Connection::session().await?),
255 }
256}
257
258pub async fn restart_service(service: UnitId, cancel_token: CancellationToken) -> Result<()> {
259 async fn restart(service: UnitId) -> Result<()> {
260 let connection = get_connection(service.scope).await?;
261 let manager_proxy = ManagerProxy::new(&connection).await?;
262 manager_proxy.restart_unit(service.name, "replace".into()).await?;
263 Ok(())
264 }
265
266 tokio::select! {
268 _ = cancel_token.cancelled() => {
269 anyhow::bail!("cancelled");
271 }
272 result = restart(service) => {
273 result
274 }
275 }
276}
277
278pub async fn sleep_test(_service: String, cancel_token: CancellationToken) -> Result<()> {
280 tokio::select! {
282 _ = cancel_token.cancelled() => {
283 anyhow::bail!("cancelled");
285 }
286 _ = tokio::time::sleep(std::time::Duration::from_secs(2)) => {
287 Ok(())
288 }
289 }
290}
291
292pub async fn kill_service(service: UnitId, signal: String, cancel_token: CancellationToken) -> Result<()> {
293 async fn kill(service: UnitId, signal: String) -> Result<()> {
294 let mut args = vec!["kill", "--signal", &signal];
295 if service.scope == UnitScope::User {
296 args.push("--user");
297 }
298 args.push(&service.name);
299
300 let output = Command::new("systemctl").args(&args).output()?;
301
302 if output.status.success() {
303 info!("Successfully sent signal {} to srvice {}", signal, service.name);
304 Ok(())
305 } else {
306 let stderr = String::from_utf8(output.stderr)?;
307 bail!("Failed to send signal {} to service {}: {}", signal, service.name, stderr);
308 }
309 }
310
311 tokio::select! {
312 _ = cancel_token.cancelled() => {
313 bail!("cancelled");
314 }
315 result = kill(service, signal) => {
316 result
317 }
318 }
319}
320
/// Async D-Bus proxy (`gen_blocking = false`) for systemd's
/// `org.freedesktop.systemd1.Manager` interface, at `/org/freedesktop/systemd1`
/// on the `org.freedesktop.systemd1` service.
#[proxy(
  interface = "org.freedesktop.systemd1.Manager",
  default_service = "org.freedesktop.systemd1",
  default_path = "/org/freedesktop/systemd1",
  gen_blocking = false
)]
pub trait Manager {
  /// Calls `StartUnit` for unit `name` with job `mode` (callers in this file pass `"replace"`).
  /// NOTE(review): the returned object path is presumably the queued job's — confirm against systemd docs.
  #[zbus(name = "StartUnit")]
  fn start_unit(&self, name: String, mode: String) -> zbus::Result<zvariant::OwnedObjectPath>;

  /// Calls `StopUnit` for unit `name` with job `mode`; returns an object path (presumably the job).
  #[zbus(name = "StopUnit")]
  fn stop_unit(&self, name: String, mode: String) -> zbus::Result<zvariant::OwnedObjectPath>;

  /// Calls `ReloadUnit` for unit `name` with job `mode`; returns an object path (presumably the job).
  #[zbus(name = "ReloadUnit")]
  fn reload_unit(&self, name: String, mode: String) -> zbus::Result<zvariant::OwnedObjectPath>;

  /// Calls `RestartUnit` for unit `name` with job `mode`; returns an object path (presumably the job).
  #[zbus(name = "RestartUnit")]
  fn restart_unit(&self, name: String, mode: String) -> zbus::Result<zvariant::OwnedObjectPath>;

  /// Calls `EnableUnitFiles` for `files`.
  /// NOTE(review): the `(bool, changes)` reply shape follows systemd's D-Bus API
  /// (install-info flag plus change triples) — confirm before relying on it.
  #[zbus(name = "EnableUnitFiles")]
  fn enable_unit_files(
    &self,
    files: Vec<String>,
    runtime: bool,
    force: bool,
  ) -> zbus::Result<(bool, Vec<(String, String, String)>)>;

  /// Calls `DisableUnitFiles` for `files`; returns a list of string triples describing the changes.
  #[zbus(name = "DisableUnitFiles")]
  fn disable_unit_files(&self, files: Vec<String>, runtime: bool) -> zbus::Result<Vec<(String, String, String)>>;

  /// Calls `ListUnits`; each entry has the same 10-field layout as `RawUnit`.
  #[zbus(name = "ListUnits")]
  fn list_units(
    &self,
  ) -> zbus::Result<
    Vec<(
      String,
      String,
      String,
      String,
      String,
      String,
      zvariant::OwnedObjectPath,
      u32,
      String,
      zvariant::OwnedObjectPath,
    )>,
  >;

  /// Calls `ListUnitsByPatterns`, filtering by `states` and name `patterns`.
  /// `get_services` passes no states and the requested service patterns.
  #[zbus(name = "ListUnitsByPatterns")]
  fn list_units_by_patterns(
    &self,
    states: Vec<String>,
    patterns: Vec<String>,
  ) -> zbus::Result<
    Vec<(
      String,
      String,
      String,
      String,
      String,
      String,
      zvariant::OwnedObjectPath,
      u32,
      String,
      zvariant::OwnedObjectPath,
    )>,
  >;

  /// Calls the manager's `Reload` method (what `systemctl daemon-reload` triggers, per the hint in `reload`).
  #[zbus(name = "Reload")]
  fn reload(&self) -> zbus::Result<()>;
}
403
/// Async D-Bus proxy for the `org.freedesktop.systemd1.Unit` interface of one unit.
///
/// No default object path is configured (`assume_defaults = false`), so callers
/// pass the unit's path explicitly (built with `get_unit_path`).
#[proxy(
  interface = "org.freedesktop.systemd1.Unit",
  default_service = "org.freedesktop.systemd1",
  assume_defaults = false,
  gen_blocking = false
)]
pub trait Unit {
  /// Reads the unit's active-state property (`ActiveState` under zbus's default name mapping).
  #[zbus(property)]
  fn active_state(&self) -> zbus::Result<String>;

  /// Reads the unit's load-state property (`LoadState` under zbus's default name mapping).
  #[zbus(property)]
  fn load_state(&self) -> zbus::Result<String>;

  /// Reads the unit's unit-file-state property (`UnitFileState` under zbus's default name mapping).
  #[zbus(property)]
  fn unit_file_state(&self) -> zbus::Result<String>;
}
425
/// Async D-Bus proxy for the `org.freedesktop.systemd1.Service` interface of one
/// unit; like `Unit`, it needs an explicit object path (see `get_main_pid`).
#[proxy(
  interface = "org.freedesktop.systemd1.Service",
  default_service = "org.freedesktop.systemd1",
  assume_defaults = false,
  gen_blocking = false
)]
trait Service {
  /// Reads the `MainPID` property. The explicit name is needed because the
  /// default snake_case mapping would not produce the `PID` capitalization.
  #[zbus(property, name = "MainPID")]
  fn main_pid(&self) -> zbus::Result<u32>;
}
439
440pub async fn get_active_state(connection: &Connection, full_service_name: &str) -> String {
450 let object_path = get_unit_path(full_service_name);
451
452 match zvariant::ObjectPath::try_from(object_path) {
453 Ok(path) => {
454 let unit_proxy = UnitProxy::new(connection, path).await.unwrap();
455 unit_proxy.active_state().await.unwrap_or("invalid-unit-path".into())
456 },
457 Err(_) => "invalid-unit-path".to_string(),
458 }
459}
460
461pub async fn get_unit_file_state(connection: &Connection, full_service_name: &str) -> String {
471 let object_path = get_unit_path(full_service_name);
472
473 match zvariant::ObjectPath::try_from(object_path) {
474 Ok(path) => {
475 let unit_proxy = UnitProxy::new(connection, path).await.unwrap();
476 unit_proxy.unit_file_state().await.unwrap_or("invalid-unit-path".into())
477 },
478 Err(_) => "invalid-unit-path".to_string(),
479 }
480}
481
482pub async fn get_main_pid(connection: &Connection, full_service_name: &str) -> Result<u32, zbus::Error> {
490 let object_path = get_unit_path(full_service_name);
491
492 let validated_object_path = zvariant::ObjectPath::try_from(object_path).unwrap();
493
494 let service_proxy = ServiceProxy::new(connection, validated_object_path).await.unwrap();
495 service_proxy.main_pid().await
496}
497
/// Escapes `input_string` into a single D-Bus object-path element, using the
/// same scheme as systemd's `bus_label_escape` (used by `sd_bus_path_encode`).
///
/// The rules are:
/// - ASCII letters are kept.
/// - ASCII digits are kept, except as the first character.
/// - Every other byte becomes `_` followed by two lowercase hex digits. This
///   includes `_`, `/`, and each byte of a multi-byte UTF-8 character.
/// - An empty input encodes as `_`.
///
/// Escaping `_` matters: left as-is, `foo_bar` would be decoded by systemd as
/// the escape `_ba` and address the wrong object.
fn encode_as_dbus_object_path(input_string: &str) -> String {
  if input_string.is_empty() {
    return "_".to_string();
  }

  let mut encoded = String::with_capacity(input_string.len() * 3);
  for (index, byte) in input_string.bytes().enumerate() {
    if byte.is_ascii_alphabetic() || (index > 0 && byte.is_ascii_digit()) {
      encoded.push(char::from(byte));
    } else {
      encoded.push_str(&format!("_{:02x}", byte));
    }
  }
  encoded
}
510
511pub fn get_unit_path(full_service_name: &str) -> String {
518 format!("/org/freedesktop/systemd1/unit/{}", encode_as_dbus_object_path(full_service_name))
519}
520
#[cfg(test)]
mod tests {
  use super::*;

  /// The escaped unit name is appended to systemd's unit object-path prefix.
  #[test]
  fn test_get_unit_path() {
    let path = get_unit_path("test.service");
    assert_eq!(path, "/org/freedesktop/systemd1/unit/test_2eservice");
  }

  /// `.` and `-` are hex-escaped while letters pass through unchanged.
  #[test]
  fn test_encode_as_dbus_object_path() {
    let cases = [
      ("test.service", "test_2eservice"),
      ("test-with-hyphen.service", "test_2dwith_2dhyphen_2eservice"),
    ];
    for (input, expected) in cases {
      assert_eq!(encode_as_dbus_object_path(input), expected);
    }
  }
}
535}