1#![deny(unsafe_op_in_unsafe_fn)]
2#![warn(missing_docs, clippy::pedantic)]
3use std::{
11 env::VarError,
12 ffi::{CString, OsStr},
13 fmt::Display,
14 net::{IpAddr, Ipv4Addr, SocketAddr},
15 os::{
16 fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd},
17 unix::ffi::OsStrExt,
18 },
19 path::{Path, PathBuf},
20 sync::{atomic::AtomicU32, LazyLock},
21 time::Duration,
22};
23
24use bye::Error;
25use nix::{
26 errno::Errno,
27 fcntl::{FcntlArg, FdFlag, OFlag, fcntl},
28 sys::{socket::getsockname, wait::waitpid},
29 unistd::{ForkResult, execve, fork, pipe2, read},
30};
31use tokio::{io::unix::AsyncFd, net::TcpListener};
32
33pub use bye::Bye;
34pub use tokio_util::sync::CancellationToken;
35
36#[cfg(feature = "tracing")]
37use tracing::{error, info, warn};
38
39mod bye;
40
/// Name of the environment variable that carries the raw fd number of the
/// readiness pipe's write end into a re-exec'd process: `fork_and_exec`
/// sets it for the child, and `ready` reads it.
const READY_ENV: &str = "UPGRADE_FD";
/// PID of the process that most recently entered `ready`; `ready` returns
/// early without doing anything when this already equals the current PID.
static READY_LAST_PID: AtomicU32 = AtomicU32::new(0);
43
44#[derive(Debug)]
45struct UpgradeUsr1 {
46 exe_path: CString,
47 args: Vec<CString>,
48 env: Vec<CString>,
49}
50
51impl UpgradeUsr1 {
52 pub fn new(exe_path: &Path) -> bye::Result<Self> {
53 let exe_path = CString::new(exe_path.as_os_str().as_bytes())?;
54 let args = std::env::args_os()
55 .map(|arg| CString::new(arg.as_bytes()))
56 .collect::<Result<Vec<_>, _>>()?;
57
58 let env = std::env::vars_os()
59 .map(|(key, value)| {
60 let mut kv = Vec::new();
61 kv.extend(key.as_bytes());
62 kv.push(b'=');
63 kv.extend(value.as_bytes());
64 CString::new(kv)
65 })
66 .collect::<Result<Vec<_>, _>>()?;
67
68 Ok(Self {
69 exe_path,
70 args,
71 env,
72 })
73 }
74
75 pub async fn upgrade(&self, timeout: Option<Duration>) -> bye::Result<bool> {
76 fork_and_exec(&self.exe_path, &self.args, &self.env, timeout).await
77 }
78}
79
/// Returns `true` if `e` is a `KEY=VALUE` entry whose key equals `key`,
/// compared ASCII case-insensitively. Entries without `=` never match.
///
/// Takes `&CStr` so both `&CString` and `&CStr` arguments are accepted
/// (the former via deref coercion).
///
/// NOTE(review): environment variable names are case-sensitive on Unix, so
/// this also matches e.g. `upgrade_fd=...` when asked for `UPGRADE_FD`.
fn cstr_kv_eq_ignore_ascii(e: &std::ffi::CStr, key: &[u8]) -> bool {
    let bytes = e.to_bytes();
    bytes
        .iter()
        .position(|&b| b == b'=')
        .is_some_and(|eq| bytes[..eq].eq_ignore_ascii_case(key))
}
88
89fn replace_or_push_env<D: Display>(env: &mut Vec<CString>, key: &str, val: D) -> bye::Result<()> {
90 let kbytes = key.as_bytes();
91 if let Some(idx) = env.iter().position(|e| cstr_kv_eq_ignore_ascii(e, kbytes)) {
92 env[idx] = CString::new(format!("{key}={val}"))?;
93 } else {
94 env.push(CString::new(format!("{key}={val}"))?);
95 }
96 Ok(())
97}
98
/// Looks up `key` (ASCII case-insensitively) in a list of `KEY=VALUE`
/// entries and returns the first matching value that parses as a `u32`.
///
/// Entries without `=`, values that are not UTF-8, and values that fail to
/// parse are skipped, so a later entry with the same key can still supply
/// the result. Returns `None` when nothing qualifies.
fn get_env_u32(env: &[CString], key: &str) -> Option<u32> {
    env.iter().find_map(|entry| {
        let raw = entry.as_bytes();
        let split_at = raw.iter().position(|&b| b == b'=')?;
        let (name, rest) = raw.split_at(split_at);
        if !name.eq_ignore_ascii_case(key.as_bytes()) {
            return None;
        }
        // `rest` starts with the `=` separator; skip it.
        std::str::from_utf8(&rest[1..]).ok()?.parse::<u32>().ok()
    })
}
115
116async fn fork_and_exec(
117 path: &CString,
118 args: &[CString],
119 env: &[CString],
120 timeout: Option<Duration>,
121) -> bye::Result<bool> {
122 let (read_fd, write_fd) = pipe2(OFlag::O_CLOEXEC).map_err(Error::Pipe2)?;
123
124 clear_cloexec(&write_fd)?;
125 set_nonblocking(&read_fd)?;
126
127 let mut env_with_fd = Vec::with_capacity(env.len() + 1);
128 env_with_fd.extend(
129 env.iter()
130 .filter(|e| !cstr_kv_eq_ignore_ascii(e, READY_ENV.as_bytes()))
131 .cloned(),
132 );
133 env_with_fd.push(CString::new(format!(
134 "{}={}",
135 READY_ENV,
136 write_fd.as_raw_fd()
137 ))?);
138
139 let argv_ref: Vec<&_> = args.iter().map(std::ffi::CString::as_c_str).collect();
141
142 match unsafe { fork().map_err(Error::Fork)? } {
143 ForkResult::Parent { child } => {
144 drop(env_with_fd);
145
146 drop(write_fd);
148
149 let af = AsyncFd::new(read_fd)?;
150 let fut = async move {
153 let mut buf = [0; 1];
154 loop {
155 let mut guard = af.readable().await?;
156 match read(&af, &mut buf) {
157 Ok(0) => return Ok(false),
158 Ok(_) => return Ok(true),
159 Err(Errno::EAGAIN) => {
160 guard.clear_ready();
162 }
163 Err(Errno::EINTR) => {
164 guard.clear_ready();
166 }
167 Err(e) => {
168 return Err(Error::Nix(e));
169 }
170 }
171 }
172 };
173
174 let result = if let Some(dur) = timeout {
175 tokio::time::timeout(dur, fut).await
176 } else {
177 Ok(fut.await)
178 };
179
180 match result {
181 Ok(Ok(true)) => Ok(true),
182 Ok(Ok(false)) => Ok(false),
183 Ok(Err(e)) => Err(e),
184 Err(_) => {
185 use nix::sys::signal::{Signal, kill};
186 kill(child, Signal::SIGKILL).map_err(|e| Error::KillChild {
187 pid: child.into(),
188 source: e,
189 })?;
190 waitpid(child, None).map_err(Error::WaitPid)?;
191 Err(Error::ChildTimeout)
192 }
193 }
194 }
195 ForkResult::Child => {
196 drop(read_fd);
197
198 if let Some(nfds) = get_env_u32(&env_with_fd, "LISTEN_FDS") {
199 if nfds > 0 {
200 let child_pid = std::process::id();
201 replace_or_push_env(&mut env_with_fd, "LISTEN_PID", child_pid).ok();
202
203 for i in 0..nfds {
204 #[allow(clippy::cast_possible_wrap)]
205 let fd = (3 + i) as i32;
206 let borrowed = unsafe { BorrowedFd::borrow_raw(fd) };
207 if let Err(e) = clear_cloexec(&borrowed) {
208 #[cfg(feature = "tracing")]
209 error!("failed to clear cloexec on fd {}: {}", fd, e);
210 std::process::exit(127);
211 }
212 }
213 }
214 }
215
216 let envp: Vec<&_> = env_with_fd.iter().map(std::ffi::CString::as_c_str).collect();
217
218 execve(path, &argv_ref, &envp).map_err(Error::Execve)?;
219 std::process::exit(127);
220 }
221 }
222}
223
224fn parse_fd_from_env(v: &OsStr) -> bye::Result<i32> {
225 let s = std::str::from_utf8(v.as_bytes()).map_err(|e| Error::EnvUtf8 {
226 key: "UPGRADE_FD",
227 source: e,
228 })?;
229 let fd = s.parse::<i32>().map_err(|e| Error::EnvParse {
230 key: "UPGRADE_FD",
231 source: e,
232 })?;
233 if fd < 0 {
234 return Err(Error::InvalidUpgradeFd);
235 }
236 Ok(fd)
237}
238
/// Returns the PID file path configured through the `PIDFILE` environment
/// variable.
///
/// # Errors
///
/// Returns the [`VarError`] from [`std::env::var`] when `PIDFILE` is unset or
/// is not valid Unicode.
pub fn try_pid_file() -> Result<PathBuf, VarError> {
    std::env::var("PIDFILE").map(PathBuf::from)
}
247
248static SYSTEMD_PORTS: LazyLock<Vec<u16>> = std::sync::LazyLock::new(|| {
249 compute_systemd_ports().unwrap_or_else(|e| {
250 #[cfg(feature = "tracing")]
251 error!("Failed to compute systemd ports: {}", e);
252 vec![]
253 })
254});
255
256#[must_use]
259pub fn systemd_ports() -> &'static [u16] {
260 &SYSTEMD_PORTS
261}
262
263fn compute_systemd_ports() -> bye::Result<Vec<u16>> {
264 let listen_fds = std::env::var("LISTEN_FDS")
265 .unwrap_or("0".to_string())
266 .parse::<u32>()
267 .map_err(|e| Error::EnvParse {
268 key: "LISTEN_FDS",
269 source: e,
270 })?;
271
272 let listen_pid = std::env::var("LISTEN_PID")
273 .unwrap_or("0".to_string())
274 .parse::<u32>()
275 .map_err(|e| Error::EnvParse {
276 key: "LISTEN_PID",
277 source: e,
278 })?;
279
280 if listen_fds == 0 || listen_pid != std::process::id() {
281 return Err(Error::SystemdActivation);
282 }
283
284 let mut ports = Vec::with_capacity(listen_fds as usize);
285 for i in 0..listen_fds {
286 #[allow(clippy::cast_possible_wrap)]
287 let fd = (3 + i) as i32;
288
289 let port = match getsockname::<nix::sys::socket::SockaddrStorage>(fd) {
290 Ok(v) => v
291 .as_sockaddr_in()
292 .map(nix::sys::socket::SockaddrIn::port)
293 .or_else(|| v.as_sockaddr_in6().map(nix::sys::socket::SockaddrIn6::port))
294 .unwrap_or(0),
295 Err(e) => {
296 return Err(Error::Sockname(e));
297 }
298 };
299 if port != 0 {
300 ports.push(port);
301 }
302 }
303
304 Ok(ports)
305}
306
307pub async fn systemd_tcp_listener(port: u16) -> bye::Result<TcpListener> {
313 let systemd_ports = SYSTEMD_PORTS.iter().position(|&p| p == port);
314 let listener = if let Some(systemd_port) = systemd_ports {
315 #[allow(clippy::cast_possible_truncation)]
316 #[allow(clippy::cast_possible_wrap)]
317 let fd = (3 + systemd_port) as i32;
318 #[cfg(feature = "tracing")]
319 info!("using systemd socket fd: {}", fd);
320 let raw_listener = unsafe { std::net::TcpListener::from_raw_fd(fd) };
321 raw_listener.set_nonblocking(true)?;
322 TcpListener::from_std(raw_listener)?
323 } else {
324 #[cfg(feature = "tracing")]
325 warn!("no systemd socket found for port {}", port);
326 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port);
327 TcpListener::bind(addr).await?
328 };
329
330 Ok(listener)
331}
332
333pub fn ready() -> bye::Result<()> {
340 let pid = std::process::id();
341 let prev = READY_LAST_PID.swap(pid, std::sync::atomic::Ordering::AcqRel);
342
343 if prev == pid {
344 return Ok(());
345 }
346
347 if let Some(val) = std::env::var_os(READY_ENV) {
348 let fd = parse_fd_from_env(&val)?;
349 let fd = unsafe { OwnedFd::from_raw_fd(fd) };
350 fcntl(&fd, FcntlArg::F_GETFD).map_err(|e| Error::Fcntl {
351 op: "F_GETFD",
352 source: e,
353 })?;
354
355 loop {
356 let n = nix::unistd::write(&fd, &[1u8]);
357 match n {
358 Ok(_) => break,
359 Err(Errno::EINTR | Errno::EAGAIN) => {}
360 Err(Errno::EPIPE) => {
361 #[cfg(feature = "tracing")]
362 warn!("UPGRADE_FD is closed, ignoring write");
363 break;
364 }
365 Err(e) => {
366 return Err(Error::NotifyWrite(e));
367 }
368 }
369 }
370 }
371
372 if let Ok(pid_file) = try_pid_file() {
373 let pid = std::process::id();
374 std::fs::write(pid_file, pid.to_string())?;
375 }
376
377 Ok(())
378}
379
380fn clear_cloexec<F: AsFd>(fd: &F) -> bye::Result<()> {
381 let getfd = fcntl(fd, FcntlArg::F_GETFD).map_err(|e| Error::Fcntl {
382 op: "F_GETFD",
383 source: e,
384 })?;
385 let flags = FdFlag::from_bits_truncate(getfd);
386 let new_flags = flags.difference(FdFlag::FD_CLOEXEC);
387 fcntl(fd, FcntlArg::F_SETFD(new_flags)).map_err(|e| Error::Fcntl {
388 op: "F_SETFD",
389 source: e,
390 })?;
391 Ok(())
392}
393
394fn set_nonblocking(fd: &OwnedFd) -> bye::Result<()> {
395 let getfl = fcntl(fd, FcntlArg::F_GETFL).map_err(|e| Error::Fcntl {
396 op: "F_GETFL",
397 source: e,
398 })?;
399 let flags = OFlag::from_bits_truncate(getfl);
400 let new_flags = flags | OFlag::O_NONBLOCK;
401 fcntl(fd, FcntlArg::F_SETFL(new_flags)).map_err(|e| Error::Fcntl {
402 op: "F_SETFL",
403 source: e,
404 })?;
405 Ok(())
406}