Skip to main content

proc_connector/
socket.rs

1//! Core `ProcConnector` type — safe wrapper around a Linux Netlink Proc Connector socket.
2//!
3//! # Lifecycle
4//!
5//! 1. `ProcConnector::new()` — creates the netlink socket, binds to `CN_IDX_PROC`, subscribes.
6//! 2. `recv()` / `recv_timeout()` — receive and parse process events.
7//! 3. `unsubscribe()` / `subscribe()` — toggle subscription (useful after reconnect).
8//! 4. Drop — automatically unsubscribes and closes the socket.
9
10use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
11use std::time::Duration;
12
13use crate::consts::*;
14use crate::error::{Error, Result};
15
/// Owned handle to a Linux Netlink Proc Connector socket.
///
/// [`ProcConnector::new`] creates the socket, binds it to the `CN_IDX_PROC`
/// multicast group and subscribes to process events. When the handle is
/// dropped it sends an unsubscribe request (best effort) and the descriptor
/// is then closed automatically.
///
/// # Examples
///
/// ```no_run
/// use proc_connector::ProcConnector;
///
/// let conn = ProcConnector::new().unwrap();
/// let mut buf = vec![0u8; 4096];
///
/// loop {
///     match conn.recv(&mut buf) {
///         Ok(event) => println!("got event: {event:?}"),
///         Err(e) => eprintln!("error: {e}"),
///     }
/// }
/// ```
pub struct ProcConnector {
    /// The netlink socket; closed by `OwnedFd` when this value is dropped.
    fd: OwnedFd,
}
39
40impl ProcConnector {
41    /// Create a new `ProcConnector`.
42    ///
43    /// This is a convenience constructor that:
44    /// 1. Creates a `PF_NETLINK` / `SOCK_DGRAM` socket of family `NETLINK_CONNECTOR`.
45    /// 2. Binds to the `CN_IDX_PROC` multicast group.
46    /// 3. Sends a `PROC_CN_MCAST_LISTEN` subscription message.
47    ///
48    /// # Errors
49    ///
50    /// Returns `Error::Os` if any system call fails.
51    pub fn new() -> Result<Self> {
52        let fd = create_socket()?;
53        let connector = Self { fd };
54        connector.bind()?;
55        connector.subscribe()?;
56        Ok(connector)
57    }
58
59    /// Bind the socket to the `CN_IDX_PROC` netlink group.
60    fn bind(&self) -> Result<()> {
61        let mut sa: libc::sockaddr_nl = unsafe { std::mem::zeroed() };
62        sa.nl_family = libc::AF_NETLINK as u16;
63        // nl_pad left as zeroed
64        sa.nl_pid = 0; // let kernel pick a port ID
65        sa.nl_groups = CN_IDX_PROC;
66
67        let ret = unsafe {
68            libc::bind(
69                self.fd.as_raw_fd(),
70                &sa as *const libc::sockaddr_nl as *const libc::sockaddr,
71                std::mem::size_of::<libc::sockaddr_nl>() as u32,
72            )
73        };
74
75        if ret < 0 {
76            return Err(Error::Os(std::io::Error::last_os_error()));
77        }
78        Ok(())
79    }
80
81        /// (Re-)subscribe to process events.
82    ///
83    /// Sends a `PROC_CN_MCAST_LISTEN` message via the netlink socket.
84    /// This is safe to call multiple times (e.g. after a reconnection).
85    ///
86    /// # Example
87    ///
88    /// ```no_run
89    /// # use proc_connector::ProcConnector;
90    /// let mut conn = ProcConnector::new().unwrap();
91    /// conn.subscribe().expect("subscribe");
92    /// ```
93    pub fn subscribe(&self) -> Result<()> {
94        self.send_mcast_op(PROC_CN_MCAST_LISTEN)
95    }
96
97    /// Unsubscribe from process events.
98    ///
99    /// Sends a `PROC_CN_MCAST_IGNORE` message. Automatically called on drop.
100    ///
101    /// # Example
102    ///
103    /// ```no_run
104    /// # use proc_connector::ProcConnector;
105    /// let conn = ProcConnector::new().unwrap();
106    /// conn.unsubscribe().expect("unsubscribe");
107    /// // Re-subscribe later:
108    /// conn.subscribe().expect("re-subscribe");
109    /// ```
110    pub fn unsubscribe(&self) -> Result<()> {
111        self.send_mcast_op(PROC_CN_MCAST_IGNORE)
112    }
113
114    /// Send a `proc_cn_mcast_op` command to the kernel.
115    fn send_mcast_op(&self, op: u32) -> Result<()> {
116        let nlmsg_payload_len = SIZE_CN_MSG + std::mem::size_of::<u32>();
117        let nlmsg_len = nlmsg_length(nlmsg_payload_len);
118
119        let mut buf = vec![0u8; nlmsg_len];
120        let pid = std::process::id();
121
122        // --- nlmsghdr (16 bytes, little-endian wire format) ---
123        let hdr = &mut buf[..SIZE_NLMSGHDR];
124        hdr[0..4].copy_from_slice(&(nlmsg_len as u32).to_ne_bytes()); // nlmsg_len
125        hdr[4..6].copy_from_slice(&NLMSG_MIN_TYPE.to_ne_bytes());      // nlmsg_type (≥ 16 = application-defined)
126        hdr[6..8].copy_from_slice(&NLM_F_REQUEST.to_ne_bytes());      // nlmsg_flags (request)
127        hdr[8..12].copy_from_slice(&0u32.to_ne_bytes());              // nlmsg_seq
128        hdr[12..16].copy_from_slice(&pid.to_ne_bytes());              // nlmsg_pid
129
130        // --- cn_msg (20 bytes header + op payload) ---
131        let cn_off = nlmsg_hdrlen();
132        let cn = &mut buf[cn_off..cn_off + SIZE_CN_MSG + std::mem::size_of::<u32>()];
133        // id.idx
134        cn[0..4].copy_from_slice(&CN_IDX_PROC.to_ne_bytes());
135        // id.val
136        cn[4..8].copy_from_slice(&CN_VAL_PROC.to_ne_bytes());
137        // seq
138        cn[8..12].copy_from_slice(&0u32.to_ne_bytes());
139        // ack
140        cn[12..16].copy_from_slice(&0u32.to_ne_bytes());
141        // len (u16) = sizeof(proc_cn_mcast_op) = 4
142        cn[16..18]
143            .copy_from_slice(&(std::mem::size_of::<u32>() as u16).to_ne_bytes());
144        // flags
145        cn[18..20].copy_from_slice(&0u16.to_ne_bytes());
146        // data = proc_cn_mcast_op
147        cn[20..24].copy_from_slice(&op.to_ne_bytes());
148
149        let iov = libc::iovec {
150            iov_base: buf.as_mut_ptr() as *mut libc::c_void,
151            iov_len: nlmsg_len,
152        };
153
154        let msg_hdr = libc::msghdr {
155            msg_name: std::ptr::null_mut(),
156            msg_namelen: 0,
157            msg_iov: &iov as *const libc::iovec as *mut libc::iovec,
158            msg_iovlen: 1,
159            msg_control: std::ptr::null_mut(),
160            msg_controllen: 0,
161            msg_flags: 0,
162        };
163
164        let ret = unsafe { libc::sendmsg(self.fd.as_raw_fd(), &msg_hdr, 0) };
165        if ret < 0 {
166            return Err(Error::Os(std::io::Error::last_os_error()));
167        }
168        Ok(())
169    }
170
171    /// Receive a raw netlink message into `buf`.
172    ///
173    /// On success returns the number of bytes written to `buf`.
174    ///
175    /// This is a thin wrapper around `recv(2)`. The caller is responsible
176    /// for providing a sufficiently large buffer (a page size, 4096 bytes,
177    /// is a safe default).
178    ///
179    /// # Errors
180    ///
181    /// - `Interrupted` if `EINTR` — retry the call.
182    /// - `ConnectionClosed` if recv returns 0.
183    /// - `Os` for other system call errors.
184    pub fn recv_raw(&self, buf: &mut [u8]) -> Result<usize> {
185        let len = unsafe {
186            libc::recv(
187                self.fd.as_raw_fd(),
188                buf.as_mut_ptr() as *mut libc::c_void,
189                buf.len(),
190                0,
191            )
192        };
193
194        if len < 0 {
195            let err = std::io::Error::last_os_error();
196            return match err.raw_os_error() {
197                Some(libc::EINTR) => Err(Error::Interrupted),
198                Some(libc::EAGAIN) => Err(Error::WouldBlock), // EWOULDBLOCK == EAGAIN on Linux
199                _ => Err(Error::Os(err)),
200            };
201        }
202
203        if len == 0 {
204            return Err(Error::ConnectionClosed);
205        }
206
207        Ok(len as usize)
208    }
209
210    /// Receive a raw netlink message with a timeout.
211    ///
212    /// Returns `Ok(None)` if the timeout expires before data is available.
213    /// Otherwise behaves the same as `recv_raw`.
214    ///
215    /// Uses `poll(2)` internally and only calls `recv` when data is ready.
216    ///
217    /// # Example
218    ///
219    /// ```no_run
220    /// # use proc_connector::ProcConnector;
221    /// # use std::time::Duration;
222    /// let conn = ProcConnector::new().unwrap();
223    /// let mut buf = vec![0u8; 4096];
224    ///
225    /// match conn.recv_timeout(&mut buf, Duration::from_secs(5)) {
226    ///     Ok(Some(event)) => println!("{event}"),
227    ///     Ok(None) => eprintln!("timeout, no event in 5s"),
228    ///     Err(e) => eprintln!("error: {e}"),
229    /// }
230    /// ```
231    pub fn recv_raw_timeout(&self, buf: &mut [u8], timeout: Duration) -> Result<Option<usize>> {
232        let poll_fd = libc::pollfd {
233            fd: self.fd.as_raw_fd(),
234            events: libc::POLLIN,
235            revents: 0,
236        };
237
238        let timeout_ms = timeout
239            .as_millis()
240            .try_into()
241            .unwrap_or(libc::c_int::MAX);
242
243        let ret = unsafe { libc::poll(&poll_fd as *const libc::pollfd as *mut _, 1, timeout_ms) };
244
245        if ret < 0 {
246            let err = std::io::Error::last_os_error();
247            return match err.raw_os_error() {
248                Some(libc::EINTR) => Err(Error::Interrupted),
249                _ => Err(Error::Os(err)),
250            };
251        }
252
253        if ret == 0 {
254            return Ok(None);
255        }
256
257        // recv_raw may return WouldBlock if the fd was set to non-blocking
258        // between poll() and recv(). Treat as timeout.
259        match self.recv_raw(buf) {
260            Ok(n) => Ok(Some(n)),
261            Err(Error::WouldBlock) => Ok(None),
262            Err(e) => Err(e),
263        }
264    }
265
266    /// Expose the raw file descriptor for integration with async runtimes
267    /// (`tokio::AsyncFd`, `mio`, etc.).
268    ///
269    /// The returned `RawFd` remains valid for the lifetime of this
270    /// `ProcConnector`. Do not close it manually.
271    ///
272    /// # Example
273    ///
274    /// ```no_run
275    /// # use proc_connector::ProcConnector;
276    /// # use std::os::fd::AsRawFd;
277    /// let conn = ProcConnector::new().unwrap();
278    /// let raw = conn.as_raw_fd();
279    /// assert!(raw >= 0);
280    ///
281    /// // Use with tokio:
282    /// // let async_fd = tokio::io::unix::AsyncFd::new(conn).unwrap();
283    /// ```
284    pub fn as_raw_fd(&self) -> RawFd {
285        self.fd.as_raw_fd()
286    }
287
288    /// Set the socket to non-blocking mode.
289    ///
290    /// After calling this, `recv_raw` will return `Error::WouldBlock`
291    /// when no data is available, instead of blocking.
292    ///
293    /// # Errors
294    ///
295    /// Returns `Error::Os` if `fcntl(2)` fails.
296    pub fn set_nonblocking(&self) -> Result<()> {
297        let flags = unsafe { libc::fcntl(self.fd.as_raw_fd(), libc::F_GETFL) };
298        if flags < 0 {
299            return Err(Error::Os(std::io::Error::last_os_error()));
300        }
301        let ret = unsafe { libc::fcntl(self.fd.as_raw_fd(), libc::F_SETFL, flags | libc::O_NONBLOCK) };
302        if ret < 0 {
303            return Err(Error::Os(std::io::Error::last_os_error()));
304        }
305        Ok(())
306    }
307}
308
309impl AsRawFd for ProcConnector {
310    fn as_raw_fd(&self) -> RawFd {
311        self.fd.as_raw_fd()
312    }
313}
314
315impl AsFd for ProcConnector {
316    fn as_fd(&self) -> BorrowedFd<'_> {
317        self.fd.as_fd()
318    }
319}
320
321impl Drop for ProcConnector {
322    fn drop(&mut self) {
323        // Best-effort unsubscribe; ignore errors since we're closing anyway.
324        let _ = self.unsubscribe();
325    }
326}
327
328// ---------------------------------------------------------------------------
329// Helper: create the netlink socket
330// ---------------------------------------------------------------------------
331
332fn create_socket() -> Result<OwnedFd> {
333    let fd = unsafe {
334        let fd = libc::socket(libc::PF_NETLINK, libc::SOCK_DGRAM, NETLINK_CONNECTOR);
335        if fd < 0 {
336            return Err(Error::Os(std::io::Error::last_os_error()));
337        }
338        OwnedFd::from_raw_fd(fd)
339    };
340
341    // Enable NETLINK_NO_ENOBUFS so the kernel doesn't silently drop
342    // our subscription message when the socket buffer is full.
343    let val: libc::c_int = 1;
344    let ret = unsafe {
345        libc::setsockopt(
346            fd.as_raw_fd(),
347            libc::SOL_NETLINK,
348            NETLINK_NO_ENOBUFS,
349            &val as *const libc::c_int as *const libc::c_void,
350            std::mem::size_of::<libc::c_int>() as u32,
351        )
352    };
353    if ret < 0 {
354        // Non-fatal: proceed even if this option fails.
355        let _ = std::io::Error::last_os_error();
356    }
357
358    Ok(fd)
359}