Skip to main content

coreshift_core/
unix_socket.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at https://mozilla.org/MPL/2.0/
4
5//! Low-level Unix domain socket primitives.
6//!
7//! This module exposes Linux/Android `AF_UNIX` stream socket mechanics only:
8//! bind, listen, accept, connect, chmod for filesystem sockets, peer
9//! credentials, and byte I/O through [`Fd`]. Callers own all protocol, message
10//! framing, authentication policy, daemon behavior, and socket naming.
11//!
12//! Abstract socket names are Linux/Android-only. They are encoded with a
13//! leading NUL byte in `sun_path`; interior NUL bytes in the caller-provided
14//! abstract name are preserved because the kernel uses the explicit sockaddr
15//! length, not C string termination.
16
17use crate::CoreError;
18use crate::error::syscall_ret;
19use crate::reactor::Fd;
20use std::io::Error as IoError;
21use std::os::unix::ffi::OsStrExt;
22use std::os::unix::fs::FileTypeExt;
23use std::os::unix::io::AsRawFd;
24use std::path::Path;
25
/// Return the calling thread's most recent OS error code (`errno`), or `0`
/// when the standard library cannot report one.
///
/// Call this immediately after the failing libc call, before anything else
/// has a chance to overwrite `errno`.
#[inline]
fn errno() -> i32 {
    let last = IoError::last_os_error();
    last.raw_os_error().unwrap_or_default()
}
30
31/// Owned non-blocking Unix listener descriptor.
32pub struct UnixListenerFd {
33    /// Underlying descriptor for reactor registration and raw byte helpers.
34    pub fd: Fd,
35}
36
37/// Owned non-blocking Unix stream descriptor.
38pub struct UnixStreamFd {
39    /// Underlying descriptor for reactor registration and raw byte helpers.
40    pub fd: Fd,
41}
42
43/// Result of starting a non-blocking Unix stream connection.
44pub enum UnixConnectResult {
45    /// The socket connected immediately.
46    Connected(UnixStreamFd),
47    /// The socket connection is in progress; register for writability and call
48    /// [`UnixStreamFd::finish_connect`] or [`UnixStreamFd::check_connect_error`].
49    InProgress(UnixStreamFd),
50}
51
/// Address of a Unix domain socket, borrowed from the caller.
#[derive(Clone, Copy, Debug)]
pub enum UnixSocketAddr<'a> {
    /// Filesystem pathname socket. Must be non-empty, contain no NUL bytes,
    /// and fit in `sun_path` together with a trailing NUL.
    Path(&'a Path),
    /// Linux/Android abstract-namespace name, given *without* the leading NUL
    /// byte. Must be non-empty; interior NUL bytes are preserved.
    Abstract(&'a [u8]),
}
60
/// How [`bind_unix_listener`] treats an entry already present at a filesystem
/// socket path before binding.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum StaleSocketPolicy {
    /// Leave any existing entry alone and let `bind` report the conflict.
    /// This is the default.
    #[default]
    Preserve,
    /// Unlink the existing entry only when `lstat` shows it is a socket; any
    /// other file type fails with `EEXIST`. A missing path is not an error.
    UnlinkSocketOnly,
    /// Unlink whatever entry exists at the path; a missing path is not an
    /// error.
    ///
    /// This can delete non-socket files, so use it only when the caller owns
    /// the path namespace.
    UnlinkAnyPath,
}

/// Options for [`bind_unix_listener`].
///
/// Both fields only apply to filesystem sockets. Binding an abstract address
/// with a policy other than [`StaleSocketPolicy::Preserve`], or with a mode,
/// fails with `EINVAL`.
#[derive(Clone, Copy, Debug, Default)]
pub struct UnixSocketBindOptions {
    /// What to do with an existing filesystem entry before binding.
    pub stale_socket_policy: StaleSocketPolicy,
    /// Permission bits to `chmod` onto the socket path after a successful
    /// bind. `None` leaves the mode that `bind` created the path with.
    pub mode: Option<u32>,
}
84
/// Credentials of the process on the other end of a Unix stream socket, as
/// reported by the platform.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PeerCred {
    /// Peer process id, if the platform reports one.
    pub pid: Option<i32>,
    /// Peer user id.
    pub uid: u32,
    /// Peer group id.
    pub gid: u32,
}
95
96impl UnixListenerFd {
97    /// Accept one non-blocking client.
98    ///
99    /// Returns `Ok(None)` if no client is ready.
100    pub fn accept(&self) -> Result<Option<UnixStreamFd>, CoreError> {
101        loop {
102            let fd = unsafe {
103                libc::accept4(
104                    self.fd.as_raw_fd(),
105                    std::ptr::null_mut(),
106                    std::ptr::null_mut(),
107                    libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK,
108                )
109            };
110            if fd >= 0 {
111                return Ok(Some(UnixStreamFd {
112                    fd: Fd::new(fd, "accept4")?,
113                }));
114            }
115
116            let e = errno();
117            if e == libc::EINTR {
118                continue;
119            }
120            if e == libc::EAGAIN || e == libc::EWOULDBLOCK {
121                return Ok(None);
122            }
123            return Err(CoreError::sys(e, "accept4"));
124        }
125    }
126}
127
128impl UnixStreamFd {
129    /// Return peer credentials when the platform supports `SO_PEERCRED`.
130    pub fn peer_cred(&self) -> Result<Option<PeerCred>, CoreError> {
131        peer_cred_raw(&self.fd)
132    }
133
134    /// Return the pending `SO_ERROR` connect status.
135    ///
136    /// `Ok(None)` means no pending socket error was reported. `Ok(Some(code))`
137    /// returns the raw connect error without making a policy decision.
138    pub fn check_connect_error(&self) -> Result<Option<i32>, CoreError> {
139        let mut code: libc::c_int = 0;
140        let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
141        let ret = unsafe {
142            libc::getsockopt(
143                self.fd.as_raw_fd(),
144                libc::SOL_SOCKET,
145                libc::SO_ERROR,
146                (&mut code as *mut libc::c_int).cast(),
147                &mut len,
148            )
149        };
150        syscall_ret(ret, "getsockopt(SO_ERROR)")?;
151        if code == 0 { Ok(None) } else { Ok(Some(code)) }
152    }
153
154    /// Finish a non-blocking connect after the socket becomes writable.
155    ///
156    /// Returns the stream when `SO_ERROR` is clear; otherwise returns the raw
157    /// socket error as [`CoreError`].
158    pub fn finish_connect(self) -> Result<Self, CoreError> {
159        match self.check_connect_error()? {
160            None => Ok(self),
161            Some(code) => Err(CoreError::sys(code, "connect(SO_ERROR)")),
162        }
163    }
164}
165
166/// Bind and listen on a non-blocking Unix stream socket.
167pub fn bind_unix_listener(
168    addr: UnixSocketAddr<'_>,
169    opts: UnixSocketBindOptions,
170) -> Result<UnixListenerFd, CoreError> {
171    let encoded = UnixSockAddr::new(addr, "unix bind address")?;
172
173    match addr {
174        UnixSocketAddr::Path(path) => {
175            apply_stale_socket_policy(path, opts.stale_socket_policy)?;
176        }
177        UnixSocketAddr::Abstract(_) => {
178            if opts.stale_socket_policy != StaleSocketPolicy::Preserve || opts.mode.is_some() {
179                return Err(CoreError::sys(libc::EINVAL, "abstract unix bind options"));
180            }
181        }
182    }
183
184    let fd = new_unix_stream_socket()?;
185    let ret = unsafe { libc::bind(fd.as_raw_fd(), encoded.as_ptr(), encoded.len()) };
186    syscall_ret(ret, "bind")?;
187
188    if let (UnixSocketAddr::Path(path), Some(mode)) = (addr, opts.mode) {
189        if let Err(err) = chmod_unix_socket(UnixSocketAddr::Path(path), mode) {
190            cleanup_created_path(addr);
191            return Err(err);
192        }
193    }
194
195    let ret = unsafe { libc::listen(fd.as_raw_fd(), libc::SOMAXCONN) };
196    if let Err(err) = syscall_ret(ret, "listen") {
197        cleanup_created_path(addr);
198        return Err(err);
199    }
200
201    Ok(UnixListenerFd { fd })
202}
203
204/// Connect a non-blocking Unix stream socket.
205pub fn connect_unix_stream(addr: UnixSocketAddr<'_>) -> Result<UnixConnectResult, CoreError> {
206    let encoded = UnixSockAddr::new(addr, "unix connect address")?;
207    let fd = new_unix_stream_socket()?;
208
209    loop {
210        let ret = unsafe { libc::connect(fd.as_raw_fd(), encoded.as_ptr(), encoded.len()) };
211        if ret == 0 {
212            return Ok(UnixConnectResult::Connected(UnixStreamFd { fd }));
213        }
214
215        let e = errno();
216        if e == libc::EINTR {
217            continue;
218        }
219        if e == libc::EINPROGRESS || e == libc::EALREADY {
220            return Ok(UnixConnectResult::InProgress(UnixStreamFd { fd }));
221        }
222        if e == libc::EISCONN {
223            return Ok(UnixConnectResult::Connected(UnixStreamFd { fd }));
224        }
225        return Err(CoreError::sys(e, "connect"));
226    }
227}
228
229/// Change mode bits on a Unix socket filesystem path.
230pub fn chmod_unix_socket(addr: UnixSocketAddr<'_>, mode: u32) -> Result<(), CoreError> {
231    match addr {
232        UnixSocketAddr::Path(path) => {
233            let metadata = std::fs::symlink_metadata(path).map_err(|err| {
234                CoreError::sys(
235                    err.raw_os_error().unwrap_or(libc::EIO),
236                    "lstat unix socket path",
237                )
238            })?;
239            if !metadata.file_type().is_socket() {
240                return Err(CoreError::sys(libc::EINVAL, "chmod unix socket path"));
241            }
242            let c_path = path_cstring(path, "chmod unix socket path")?;
243            let ret = unsafe { libc::chmod(c_path.as_ptr(), mode as libc::mode_t) };
244            syscall_ret(ret, "chmod")
245        }
246        UnixSocketAddr::Abstract(_) => Err(CoreError::sys(libc::EINVAL, "chmod abstract socket")),
247    }
248}
249
250/// Change mode bits on a Unix socket filesystem path.
251pub fn chmod_socket_path(path: impl AsRef<Path>, mode: u32) -> Result<(), CoreError> {
252    chmod_unix_socket(UnixSocketAddr::Path(path.as_ref()), mode)
253}
254
255fn new_unix_stream_socket() -> Result<Fd, CoreError> {
256    let fd = unsafe {
257        libc::socket(
258            libc::AF_UNIX,
259            libc::SOCK_STREAM | libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK,
260            0,
261        )
262    };
263    syscall_ret(fd, "socket(AF_UNIX)")?;
264    Fd::new(fd, "socket(AF_UNIX)")
265}
266
267fn apply_stale_socket_policy(path: &Path, policy: StaleSocketPolicy) -> Result<(), CoreError> {
268    match policy {
269        StaleSocketPolicy::Preserve => Ok(()),
270        StaleSocketPolicy::UnlinkSocketOnly => {
271            let metadata = match std::fs::symlink_metadata(path) {
272                Ok(metadata) => metadata,
273                Err(err) if err.raw_os_error() == Some(libc::ENOENT) => return Ok(()),
274                Err(err) => {
275                    return Err(CoreError::sys(
276                        err.raw_os_error().unwrap_or(libc::EIO),
277                        "lstat unix socket path",
278                    ));
279                }
280            };
281            if !metadata.file_type().is_socket() {
282                return Err(CoreError::sys(libc::EEXIST, "stale unix socket path"));
283            }
284            unlink_path(path, "unlink stale unix socket")
285        }
286        StaleSocketPolicy::UnlinkAnyPath => unlink_path(path, "unlink unix socket path"),
287    }
288}
289
290fn unlink_path(path: &Path, op: &'static str) -> Result<(), CoreError> {
291    match std::fs::remove_file(path) {
292        Ok(()) => Ok(()),
293        Err(err) if err.raw_os_error() == Some(libc::ENOENT) => Ok(()),
294        Err(err) => Err(CoreError::sys(err.raw_os_error().unwrap_or(libc::EIO), op)),
295    }
296}
297
298fn cleanup_created_path(addr: UnixSocketAddr<'_>) {
299    if let UnixSocketAddr::Path(path) = addr {
300        let _ = std::fs::remove_file(path);
301    }
302}
303
304struct UnixSockAddr {
305    inner: libc::sockaddr_un,
306    len: libc::socklen_t,
307}
308
309impl UnixSockAddr {
310    fn new(addr: UnixSocketAddr<'_>, op: &'static str) -> Result<Self, CoreError> {
311        let mut inner: libc::sockaddr_un = unsafe { std::mem::zeroed() };
312        inner.sun_family = libc::AF_UNIX as libc::sa_family_t;
313        let sun_path_offset = std::mem::offset_of!(libc::sockaddr_un, sun_path);
314
315        let len = match addr {
316            UnixSocketAddr::Path(path) => {
317                let bytes = path.as_os_str().as_bytes();
318                if bytes.is_empty() {
319                    return Err(CoreError::sys(libc::EINVAL, op));
320                }
321                if bytes.contains(&0) {
322                    return Err(CoreError::sys(libc::EINVAL, op));
323                }
324                if bytes.len() >= inner.sun_path.len() {
325                    return Err(CoreError::sys(libc::ENAMETOOLONG, op));
326                }
327
328                for (slot, byte) in inner.sun_path.iter_mut().zip(bytes.iter().copied()) {
329                    *slot = byte as libc::c_char;
330                }
331                sun_path_offset + bytes.len() + 1
332            }
333            UnixSocketAddr::Abstract(name) => {
334                validate_abstract_supported()?;
335                if name.is_empty() {
336                    return Err(CoreError::sys(libc::EINVAL, op));
337                }
338                if name.len() + 1 > inner.sun_path.len() {
339                    return Err(CoreError::sys(libc::ENAMETOOLONG, op));
340                }
341
342                inner.sun_path[0] = 0;
343                for (slot, byte) in inner.sun_path[1..].iter_mut().zip(name.iter().copied()) {
344                    *slot = byte as libc::c_char;
345                }
346                sun_path_offset + 1 + name.len()
347            }
348        };
349        let len = libc::socklen_t::try_from(len).map_err(|_| CoreError::sys(libc::EINVAL, op))?;
350
351        Ok(Self { inner, len })
352    }
353
354    fn len(&self) -> libc::socklen_t {
355        self.len
356    }
357
358    fn as_ptr(&self) -> *const libc::sockaddr {
359        (&self.inner as *const libc::sockaddr_un).cast()
360    }
361}
362
363fn validate_abstract_supported() -> Result<(), CoreError> {
364    if cfg!(any(target_os = "linux", target_os = "android")) {
365        Ok(())
366    } else {
367        Err(CoreError::sys(libc::ENOSYS, "abstract unix socket"))
368    }
369}
370
371fn path_cstring(path: &Path, op: &'static str) -> Result<std::ffi::CString, CoreError> {
372    std::ffi::CString::new(path.as_os_str().as_bytes())
373        .map_err(|_| CoreError::sys(libc::EINVAL, op))
374}
375
376#[cfg(any(target_os = "linux", target_os = "android"))]
377fn peer_cred_raw(fd: &Fd) -> Result<Option<PeerCred>, CoreError> {
378    let mut cred: libc::ucred = unsafe { std::mem::zeroed() };
379    let mut len = std::mem::size_of::<libc::ucred>() as libc::socklen_t;
380    let ret = unsafe {
381        libc::getsockopt(
382            fd.as_raw_fd(),
383            libc::SOL_SOCKET,
384            libc::SO_PEERCRED,
385            (&mut cred as *mut libc::ucred).cast(),
386            &mut len,
387        )
388    };
389    syscall_ret(ret, "getsockopt(SO_PEERCRED)")?;
390
391    Ok(Some(PeerCred {
392        pid: Some(cred.pid),
393        uid: cred.uid,
394        gid: cred.gid,
395    }))
396}
397
398#[cfg(not(any(target_os = "linux", target_os = "android")))]
399fn peer_cred_raw(_fd: &Fd) -> Result<Option<PeerCred>, CoreError> {
400    Ok(None)
401}