letmein_systemd/
lib.rs

1// -*- coding: utf-8 -*-
2//
3// Copyright (C) 2024 Michael Büsch <m@bues.ch>
4//
5// Licensed under the Apache License version 2.0
6// or the MIT license, at your option.
7// SPDX-License-Identifier: Apache-2.0 OR MIT
8
9//! This crate is an abstraction of the `systemd` interfaces needed by `letmein`.
10
11#[cfg(not(any(target_os = "linux", target_os = "android")))]
12std::compile_error!("letmeind server and letmein-systemd do not support non-Linux platforms.");
13
use anyhow::{self as ah, format_err as err, Context as _};

#[cfg(any(feature = "udp", feature = "tcp", feature = "unix"))]
use std::{
    mem::size_of_val,
    os::fd::{FromRawFd as _, RawFd},
};

#[cfg(feature = "udp")]
use std::net::UdpSocket;

#[cfg(feature = "tcp")]
use std::net::TcpListener;

#[cfg(feature = "unix")]
use std::os::unix::net::UnixListener;
30
/// Address families accepted for UDP and TCP sockets: `AF_INET` and `AF_INET6`.
///
/// The entries are wrapped in `Some` so that the `Option` returned by
/// `get_socket_family` can be looked up directly with `contains`.
#[cfg(any(feature = "udp", feature = "tcp"))]
const INET46: [Option<libc::c_int>; 2] = [Some(libc::AF_INET), Some(libc::AF_INET6)];
33
/// Tell whether the raw file descriptor `fd` refers to a socket.
///
/// The file type is extracted from the `st_mode` field reported by `fstat64`.
/// If `fstat64` fails, `false` is returned.
#[cfg(any(feature = "udp", feature = "tcp", feature = "unix"))]
fn is_socket(fd: RawFd) -> bool {
    // The `libc` constants are cast so that they match the type of `mode_t`.
    const FMT_MASK: libc::mode_t = libc::S_IFMT as libc::mode_t;
    const FMT_SOCK: libc::mode_t = libc::S_IFSOCK as libc::mode_t;

    // SAFETY: Initializing `libc::stat64` structure with zero is an allowed pattern.
    let mut info: libc::stat64 = unsafe { std::mem::zeroed() };
    // SAFETY: The `fd` is valid and `info` is initialized and valid.
    let rc = unsafe { libc::fstat64(fd, &mut info) };
    if rc != 0 {
        // Without file status information we cannot claim that this is a socket.
        return false;
    }
    (info.st_mode as libc::mode_t & FMT_MASK) == FMT_SOCK
}
49
/// Query the socket type (the `SO_TYPE` socket option) of the socket `fd`.
///
/// Returns `None` if `getsockopt` fails or if it reports fewer bytes
/// than a full `c_int`.
///
/// # Safety
///
/// The passed `fd` must be a socket `fd`.
#[cfg(any(feature = "udp", feature = "tcp", feature = "unix"))]
unsafe fn get_socket_type(fd: RawFd) -> Option<libc::c_int> {
    let mut kind: libc::c_int = 0;
    let expected_len = size_of_val(&kind) as libc::socklen_t;
    // In: size of the buffer. Out: number of bytes the kernel wrote.
    let mut optlen = expected_len;
    // SAFETY: The `fd` is valid, `kind` and `optlen` are initialized and valid.
    let rc = unsafe {
        libc::getsockopt(
            fd,
            libc::SOL_SOCKET,
            libc::SO_TYPE,
            (&mut kind as *mut libc::c_int).cast::<libc::c_void>(),
            &mut optlen,
        )
    };
    if rc != 0 || optlen < expected_len {
        return None;
    }
    Some(kind)
}
73
/// Get the socket family of the passed socket `fd`.
///
/// The family is read from the `sa_family` field of the address returned by
/// `getsockname`. Returns `None` if `getsockname` fails or if it did not
/// return enough bytes to cover the `sa_family` field.
///
/// # Safety
///
/// The passed `fd` must be a socket `fd`.
#[cfg(any(feature = "udp", feature = "tcp", feature = "unix"))]
unsafe fn get_socket_family(fd: RawFd) -> Option<libc::c_int> {
    // SAFETY: Initializing `libc::sockaddr` structure with zero is an allowed pattern.
    let mut saddr: libc::sockaddr = unsafe { std::mem::zeroed() };
    let mut len: libc::socklen_t = size_of_val(&saddr) as _;
    // SAFETY: The `fd` is valid, `saddr` and `len` are initialized and valid.
    let ret = unsafe { libc::getsockname(fd, &mut saddr, &mut len) };
    // Only `sa_family` (the leading field of `sockaddr`) is read below, so only
    // that field must have been filled in. Requiring a full `sockaddr` would
    // wrongly reject `AF_UNIX` sockets with a short path, an abstract name or
    // no name at all, because their address is shorter than `sockaddr`.
    let min_len = size_of_val(&saddr.sa_family) as libc::socklen_t;
    if ret == 0 && len >= min_len {
        Some(saddr.sa_family.into())
    } else {
        None
    }
}
90
/// Check if `fd` is a datagram socket of the `AF_INET` or `AF_INET6` family.
#[cfg(feature = "udp")]
fn is_udp_socket(fd: RawFd) -> bool {
    if !is_socket(fd) {
        return false;
    }
    // SAFETY: `fd` has been checked to be a socket above.
    let sotype = unsafe { get_socket_type(fd) };
    if sotype != Some(libc::SOCK_DGRAM) {
        return false;
    }
    // SAFETY: `fd` has been checked to be a socket above.
    let family = unsafe { get_socket_family(fd) };
    INET46.contains(&family)
}
100
/// Check if `fd` is a stream socket of the `AF_INET` or `AF_INET6` family.
#[cfg(feature = "tcp")]
fn is_tcp_socket(fd: RawFd) -> bool {
    if !is_socket(fd) {
        return false;
    }
    // SAFETY: `fd` has been checked to be a socket above.
    let sotype = unsafe { get_socket_type(fd) };
    if sotype != Some(libc::SOCK_STREAM) {
        return false;
    }
    // SAFETY: `fd` has been checked to be a socket above.
    let family = unsafe { get_socket_family(fd) };
    INET46.contains(&family)
}
110
/// Check if `fd` is a stream socket of the `AF_UNIX` family.
#[cfg(feature = "unix")]
fn is_unix_socket(fd: RawFd) -> bool {
    if !is_socket(fd) {
        return false;
    }
    // SAFETY: `fd` has been checked to be a socket above.
    let sotype = unsafe { get_socket_type(fd) };
    if sotype != Some(libc::SOCK_STREAM) {
        return false;
    }
    // SAFETY: `fd` has been checked to be a socket above.
    let family = unsafe { get_socket_family(fd) };
    family == Some(libc::AF_UNIX)
}
120
/// A socket that systemd handed us over.
///
/// Each variant wraps the matching standard library socket type and is only
/// available if the cargo feature of the same name is enabled.
/// Instances are created by [`SystemdSocket::get_all`].
#[derive(Debug)]
#[non_exhaustive]
pub enum SystemdSocket {
    /// UDP socket.
    ///
    /// Only available with the `udp` feature.
    #[cfg(feature = "udp")]
    Udp(UdpSocket),

    /// TCP socket.
    ///
    /// Only available with the `tcp` feature.
    #[cfg(feature = "tcp")]
    Tcp(TcpListener),

    /// Unix socket.
    ///
    /// Only available with the `unix` feature.
    #[cfg(feature = "unix")]
    Unix(UnixListener),
}
137
138impl SystemdSocket {
139    /// Get all sockets from systemd.
140    ///
141    /// All environment variables related to this operation will be cleared.
142    #[allow(unused_mut)]
143    pub fn get_all() -> ah::Result<Vec<SystemdSocket>> {
144        let mut sockets = vec![];
145        if sd_notify::booted().unwrap_or(false) {
146            for fd in sd_notify::listen_fds().context("Systemd listen_fds")? {
147                #[cfg(feature = "udp")]
148                if is_udp_socket(fd) {
149                    // SAFETY:
150                    // The fd from systemd is good and lives for the lifetime of the program.
151                    let sock = unsafe { UdpSocket::from_raw_fd(fd) };
152                    sockets.push(SystemdSocket::Udp(sock));
153                    continue;
154                }
155
156                #[cfg(feature = "tcp")]
157                if is_tcp_socket(fd) {
158                    // SAFETY:
159                    // The fd from systemd is good and lives for the lifetime of the program.
160                    let sock = unsafe { TcpListener::from_raw_fd(fd) };
161                    sockets.push(SystemdSocket::Tcp(sock));
162                    continue;
163                }
164
165                #[cfg(feature = "unix")]
166                if is_unix_socket(fd) {
167                    // SAFETY:
168                    // The fd from systemd is good and lives for the lifetime of the program.
169                    let sock = unsafe { UnixListener::from_raw_fd(fd) };
170                    sockets.push(SystemdSocket::Unix(sock));
171                    continue;
172                }
173
174                let _ = fd;
175                return Err(err!("Received unknown socket from systemd"));
176            }
177        }
178        Ok(sockets)
179    }
180}
181
182/// Notify ready-status to systemd.
183///
184/// All environment variables related to this operation will be cleared.
185pub fn systemd_notify_ready() -> ah::Result<()> {
186    sd_notify::notify(true, &[sd_notify::NotifyState::Ready])?;
187    Ok(())
188}
189
#[cfg(test)]
mod tests {
    use super::{systemd_notify_ready, SystemdSocket};

    /// Smoke test of the public API.
    ///
    /// Assumes that the test process did not get any sockets passed by systemd.
    #[test]
    fn test_systemd() {
        let sockets = SystemdSocket::get_all().unwrap();
        assert!(sockets.is_empty());

        let notified = systemd_notify_ready();
        notified.unwrap();
    }
}
201
202// vim: ts=4 sw=4 expandtab