Skip to main content

bye/
lib.rs

1#![deny(unsafe_op_in_unsafe_fn)]
2#![warn(missing_docs, clippy::pedantic)]
3//! bye: graceful shutdown and USR1 zero-downtime upgrade helpers.
4//!
5//! - `Bye` manages a `CancellationToken` broadcast + a `TaskTracker`.
6//! - `Bye::new_with_signals()` listens for TERM/INT/QUIT -> drain, USR1 -> fork+exec self and wait for child `ready()`.
7//! - `ready()` tells the parent (or your process manager) that the service is ready.
8//! - `systemd_tcp_listener(port)` uses socket activation if available.
9
10use std::{
11    env::VarError,
12    ffi::{CString, OsStr},
13    fmt::Display,
14    net::{IpAddr, Ipv4Addr, SocketAddr},
15    os::{
16        fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd},
17        unix::ffi::OsStrExt,
18    },
19    path::{Path, PathBuf},
20    sync::{atomic::AtomicU32, LazyLock},
21    time::Duration,
22};
23
24use bye::Error;
25use nix::{
26    errno::Errno,
27    fcntl::{FcntlArg, FdFlag, OFlag, fcntl},
28    sys::{socket::getsockname, wait::waitpid},
29    unistd::{ForkResult, execve, fork, pipe2, read},
30};
31use tokio::{io::unix::AsyncFd, net::TcpListener};
32
33pub use bye::Bye;
34pub use tokio_util::sync::CancellationToken;
35
36#[cfg(feature = "tracing")]
37use tracing::{error, info, warn};
38
39mod bye;
40
/// Environment variable through which a parent performing a USR1 upgrade hands the
/// write end of its readiness pipe to the re-executed child (`UPGRADE_FD=<fd>`).
const READY_ENV: &str = "UPGRADE_FD";
/// PID of the last process that entered [`ready`]; repeated calls from that same
/// process return early.
static READY_LAST_PID: AtomicU32 = AtomicU32::new(0);
43
44#[derive(Debug)]
45struct UpgradeUsr1 {
46    exe_path: CString,
47    args: Vec<CString>,
48    env: Vec<CString>,
49}
50
51impl UpgradeUsr1 {
52    pub fn new(exe_path: &Path) -> bye::Result<Self> {
53        let exe_path = CString::new(exe_path.as_os_str().as_bytes())?;
54        let args = std::env::args_os()
55            .map(|arg| CString::new(arg.as_bytes()))
56            .collect::<Result<Vec<_>, _>>()?;
57
58        let env = std::env::vars_os()
59            .map(|(key, value)| {
60                let mut kv = Vec::new();
61                kv.extend(key.as_bytes());
62                kv.push(b'=');
63                kv.extend(value.as_bytes());
64                CString::new(kv)
65            })
66            .collect::<Result<Vec<_>, _>>()?;
67
68        Ok(Self {
69            exe_path,
70            args,
71            env,
72        })
73    }
74
75    pub async fn upgrade(&self, timeout: Option<Duration>) -> bye::Result<bool> {
76        fork_and_exec(&self.exe_path, &self.args, &self.env, timeout).await
77    }
78}
79
/// Returns `true` when the `KEY=VALUE` entry `e` has a name equal to `key`, compared
/// ASCII case-insensitively. Entries without any `=` never match.
///
/// NOTE(review): POSIX environment names are case-sensitive, and the rest of this file
/// reads them with `std::env::var`; confirm the case-insensitive match is intended.
fn cstr_kv_eq_ignore_ascii(e: &CString, key: &[u8]) -> bool {
    let entry = e.as_bytes();
    entry
        .iter()
        .position(|&byte| byte == b'=')
        .is_some_and(|split| entry[..split].eq_ignore_ascii_case(key))
}
88
89fn replace_or_push_env<D: Display>(env: &mut Vec<CString>, key: &str, val: D) -> bye::Result<()> {
90    let kbytes = key.as_bytes();
91    if let Some(idx) = env.iter().position(|e| cstr_kv_eq_ignore_ascii(e, kbytes)) {
92        env[idx] = CString::new(format!("{key}={val}"))?;
93    } else {
94        env.push(CString::new(format!("{key}={val}"))?);
95    }
96    Ok(())
97}
98
/// Looks up `key` in a list of `KEY=VALUE` entries and parses its value as `u32`.
///
/// Names are compared ASCII case-insensitively. The first matching entry whose value
/// is valid UTF-8 and parses as `u32` wins; matching entries with unparsable values are
/// skipped rather than ending the search.
///
/// NOTE(review): POSIX environment names are case-sensitive; confirm the
/// case-insensitive match is intended.
fn get_env_u32(env: &[CString], key: &str) -> Option<u32> {
    env.iter().find_map(|entry| {
        let bytes = entry.as_bytes();
        let split = bytes.iter().position(|&b| b == b'=')?;
        let (name, rest) = bytes.split_at(split);
        if !name.eq_ignore_ascii_case(key.as_bytes()) {
            return None;
        }
        // `rest` still starts with the `=` separator.
        std::str::from_utf8(&rest[1..]).ok()?.parse::<u32>().ok()
    })
}
115
116async fn fork_and_exec(
117    path: &CString,
118    args: &[CString],
119    env: &[CString],
120    timeout: Option<Duration>,
121) -> bye::Result<bool> {
122    let (read_fd, write_fd) = pipe2(OFlag::O_CLOEXEC).map_err(Error::Pipe2)?;
123
124    clear_cloexec(&write_fd)?;
125    set_nonblocking(&read_fd)?;
126
127    let mut env_with_fd = Vec::with_capacity(env.len() + 1);
128    env_with_fd.extend(
129        env.iter()
130            .filter(|e| !cstr_kv_eq_ignore_ascii(e, READY_ENV.as_bytes()))
131            .cloned(),
132    );
133    env_with_fd.push(CString::new(format!(
134        "{}={}",
135        READY_ENV,
136        write_fd.as_raw_fd()
137    ))?);
138
139    // cstr for execve with old argv
140    let argv_ref: Vec<&_> = args.iter().map(std::ffi::CString::as_c_str).collect();
141
142    match unsafe { fork().map_err(Error::Fork)? } {
143        ForkResult::Parent { child } => {
144            drop(env_with_fd);
145
146            // make sure the write fd is closed in the parent
147            drop(write_fd);
148
149            let af = AsyncFd::new(read_fd)?;
150            // could have a timeout here if desired
151
152            let fut = async move {
153                let mut buf = [0; 1];
154                loop {
155                    let mut guard = af.readable().await?;
156                    match read(&af, &mut buf) {
157                        Ok(0) => return Ok(false),
158                        Ok(_) => return Ok(true),
159                        Err(Errno::EAGAIN) => {
160                            // continue waiting
161                            guard.clear_ready();
162                        }
163                        Err(Errno::EINTR) => {
164                            // interrupted, continue waiting
165                            guard.clear_ready();
166                        }
167                        Err(e) => {
168                            return Err(Error::Nix(e));
169                        }
170                    }
171                }
172            };
173
174            let result = if let Some(dur) = timeout {
175                tokio::time::timeout(dur, fut).await
176            } else {
177                Ok(fut.await)
178            };
179
180            match result {
181                Ok(Ok(true)) => Ok(true),
182                Ok(Ok(false)) => Ok(false),
183                Ok(Err(e)) => Err(e),
184                Err(_) => {
185                    use nix::sys::signal::{Signal, kill};
186                    kill(child, Signal::SIGKILL).map_err(|e| Error::KillChild {
187                        pid: child.into(),
188                        source: e,
189                    })?;
190                    waitpid(child, None).map_err(Error::WaitPid)?;
191                    Err(Error::ChildTimeout)
192                }
193            }
194        }
195        ForkResult::Child => {
196            drop(read_fd);
197
198            if let Some(nfds) = get_env_u32(&env_with_fd, "LISTEN_FDS") {
199                if nfds > 0 {
200                    let child_pid = std::process::id();
201                    replace_or_push_env(&mut env_with_fd, "LISTEN_PID", child_pid).ok();
202
203                    for i in 0..nfds {
204                        #[allow(clippy::cast_possible_wrap)]
205                        let fd = (3 + i) as i32;
206                        let borrowed = unsafe { BorrowedFd::borrow_raw(fd) };
207                        if let Err(e) = clear_cloexec(&borrowed) {
208                            #[cfg(feature = "tracing")]
209                            error!("failed to clear cloexec on fd {}: {}", fd, e);
210                            std::process::exit(127);
211                        }
212                    }
213                }
214            }
215
216            let envp: Vec<&_> = env_with_fd.iter().map(std::ffi::CString::as_c_str).collect();
217
218            execve(path, &argv_ref, &envp).map_err(Error::Execve)?;
219            std::process::exit(127);
220        }
221    }
222}
223
224fn parse_fd_from_env(v: &OsStr) -> bye::Result<i32> {
225    let s = std::str::from_utf8(v.as_bytes()).map_err(|e| Error::EnvUtf8 {
226        key: "UPGRADE_FD",
227        source: e,
228    })?;
229    let fd = s.parse::<i32>().map_err(|e| Error::EnvParse {
230        key: "UPGRADE_FD",
231        source: e,
232    })?;
233    if fd < 0 {
234        return Err(Error::InvalidUpgradeFd);
235    }
236    Ok(fd)
237}
238
/// Returns the PID-file path configured through the `PIDFILE` environment variable.
///
/// An empty `PIDFILE` is treated exactly like an unset one. Callers such as [`ready`]
/// therefore skip writing a PID file instead of failing on an empty path.
///
/// # Errors
/// - `VarError::NotPresent` if `PIDFILE` is unset or set to an empty string.
/// - `VarError::NotUnicode` if `PIDFILE` is not valid Unicode.
pub fn try_pid_file() -> Result<PathBuf, VarError> {
    match std::env::var("PIDFILE")? {
        path if path.is_empty() => Err(VarError::NotPresent),
        path => Ok(PathBuf::from(path)),
    }
}
247
248static SYSTEMD_PORTS: LazyLock<Vec<u16>> = std::sync::LazyLock::new(|| {
249    compute_systemd_ports().unwrap_or_else(|e| {
250        #[cfg(feature = "tracing")]
251        error!("Failed to compute systemd ports: {}", e);
252        vec![]
253    })
254});
255
256/// Returns the list of ports inherited from systemd socket activation.
257/// This is computed once and cached for the lifetime of the program.
258#[must_use]
259pub fn systemd_ports() -> &'static [u16] {
260    &SYSTEMD_PORTS
261}
262
263fn compute_systemd_ports() -> bye::Result<Vec<u16>> {
264    let listen_fds = std::env::var("LISTEN_FDS")
265        .unwrap_or("0".to_string())
266        .parse::<u32>()
267        .map_err(|e| Error::EnvParse {
268            key: "LISTEN_FDS",
269            source: e,
270        })?;
271
272    let listen_pid = std::env::var("LISTEN_PID")
273        .unwrap_or("0".to_string())
274        .parse::<u32>()
275        .map_err(|e| Error::EnvParse {
276            key: "LISTEN_PID",
277            source: e,
278        })?;
279
280    if listen_fds == 0 || listen_pid != std::process::id() {
281        return Err(Error::SystemdActivation);
282    }
283
284    let mut ports = Vec::with_capacity(listen_fds as usize);
285    for i in 0..listen_fds {
286        #[allow(clippy::cast_possible_wrap)]
287        let fd = (3 + i) as i32;
288
289        let port = match getsockname::<nix::sys::socket::SockaddrStorage>(fd) {
290            Ok(v) => v
291                .as_sockaddr_in()
292                .map(nix::sys::socket::SockaddrIn::port)
293                .or_else(|| v.as_sockaddr_in6().map(nix::sys::socket::SockaddrIn6::port))
294                .unwrap_or(0),
295            Err(e) => {
296                return Err(Error::Sockname(e));
297            }
298        };
299        if port != 0 {
300            ports.push(port);
301        }
302    }
303
304    Ok(ports)
305}
306
307/// Creates a TCP listener that uses systemd socket activation if available.
308/// If no systemd socket is found for the given port, it falls back to binding a new socket.
309/// # Errors
310/// - If the systemd socket activation is enabled but no socket is found for the given port.
311/// - If binding a new socket fails.
312pub async fn systemd_tcp_listener(port: u16) -> bye::Result<TcpListener> {
313    let systemd_ports = SYSTEMD_PORTS.iter().position(|&p| p == port);
314    let listener = if let Some(systemd_port) = systemd_ports {
315        #[allow(clippy::cast_possible_truncation)]
316        #[allow(clippy::cast_possible_wrap)]
317        let fd = (3 + systemd_port) as i32;
318        #[cfg(feature = "tracing")]
319        info!("using systemd socket fd: {}", fd);
320        let raw_listener = unsafe { std::net::TcpListener::from_raw_fd(fd) };
321        raw_listener.set_nonblocking(true)?;
322        TcpListener::from_std(raw_listener)?
323    } else {
324        #[cfg(feature = "tracing")]
325        warn!("no systemd socket found for port {}", port);
326        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port);
327        TcpListener::bind(addr).await?
328    };
329
330    Ok(listener)
331}
332
333/// Notifies the system that the service is ready.
334/// This function should be called once the service is fully initialized and ready to accept requests.
335/// # Errors
336/// - If the `UPGRADE_FD` environment variable is set but invalid.
337/// - If writing to the `UPGRADE_FD` fails.
338/// - If writing to the PID file (if `PIDFILE` is set) fails.
339pub fn ready() -> bye::Result<()> {
340    let pid = std::process::id();
341    let prev = READY_LAST_PID.swap(pid, std::sync::atomic::Ordering::AcqRel);
342
343    if prev == pid {
344        return Ok(());
345    }
346
347    if let Some(val) = std::env::var_os(READY_ENV) {
348        let fd = parse_fd_from_env(&val)?;
349        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
350        fcntl(&fd, FcntlArg::F_GETFD).map_err(|e| Error::Fcntl {
351            op: "F_GETFD",
352            source: e,
353        })?;
354
355        loop {
356            let n = nix::unistd::write(&fd, &[1u8]);
357            match n {
358                Ok(_) => break,
359                Err(Errno::EINTR | Errno::EAGAIN) => {}
360                Err(Errno::EPIPE) => {
361                    #[cfg(feature = "tracing")]
362                    warn!("UPGRADE_FD is closed, ignoring write");
363                    break;
364                }
365                Err(e) => {
366                    return Err(Error::NotifyWrite(e));
367                }
368            }
369        }
370    }
371
372    if let Ok(pid_file) = try_pid_file() {
373        let pid = std::process::id();
374        std::fs::write(pid_file, pid.to_string())?;
375    }
376
377    Ok(())
378}
379
380fn clear_cloexec<F: AsFd>(fd: &F) -> bye::Result<()> {
381    let getfd = fcntl(fd, FcntlArg::F_GETFD).map_err(|e| Error::Fcntl {
382        op: "F_GETFD",
383        source: e,
384    })?;
385    let flags = FdFlag::from_bits_truncate(getfd);
386    let new_flags = flags.difference(FdFlag::FD_CLOEXEC);
387    fcntl(fd, FcntlArg::F_SETFD(new_flags)).map_err(|e| Error::Fcntl {
388        op: "F_SETFD",
389        source: e,
390    })?;
391    Ok(())
392}
393
394fn set_nonblocking(fd: &OwnedFd) -> bye::Result<()> {
395    let getfl = fcntl(fd, FcntlArg::F_GETFL).map_err(|e| Error::Fcntl {
396        op: "F_GETFL",
397        source: e,
398    })?;
399    let flags = OFlag::from_bits_truncate(getfl);
400    let new_flags = flags | OFlag::O_NONBLOCK;
401    fcntl(fd, FcntlArg::F_SETFL(new_flags)).map_err(|e| Error::Fcntl {
402        op: "F_SETFL",
403        source: e,
404    })?;
405    Ok(())
406}