Skip to main content

proc_connector/
socket.rs

//! Core `ProcConnector` type — safe wrapper around a Linux Netlink Proc Connector socket.
//!
//! # Lifecycle
//!
//! 1. `ProcConnector::new()` — creates the netlink socket, binds to `CN_IDX_PROC`, subscribes.
//! 2. `recv()` / `recv_timeout()` — receive and parse process events.
//! 3. `unsubscribe()` / `subscribe()` — toggle subscription (useful after reconnect).
//! 4. Drop — automatically unsubscribes and closes the socket.
9
10use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
11use std::time::Duration;
12
13use crate::consts::*;
14use crate::error::{Error, Result};
15use crate::proc_event::ProcEvent;
16
/// A safe handle to a Linux Netlink Proc Connector socket.
///
/// The socket is created, bound, and subscribed in `new()`. On drop, the
/// subscription is cancelled and the file descriptor is closed automatically.
///
/// # Examples
///
/// ```no_run
/// use proc_connector::ProcConnector;
///
/// let conn = ProcConnector::new().unwrap();
/// let mut buf = vec![0u8; 4096];
///
/// loop {
///     match conn.recv(&mut buf) {
///         Ok(event) => println!("got event: {event:?}"),
///         Err(e) => {
///             // Stop instead of spinning on a persistent error such as a closed socket.
///             eprintln!("error: {e}");
///             break;
///         }
///     }
/// }
/// ```
#[derive(Debug)]
pub struct ProcConnector {
    /// The netlink socket; closed automatically when this value is dropped.
    fd: OwnedFd,
}
40
41impl ProcConnector {
42    /// Create a new `ProcConnector`.
43    ///
44    /// This is a convenience constructor that:
45    /// 1. Creates a `PF_NETLINK` / `SOCK_DGRAM` socket of family `NETLINK_CONNECTOR`.
46    /// 2. Binds to the `CN_IDX_PROC` multicast group.
47    /// 3. Sends a `PROC_CN_MCAST_LISTEN` subscription message.
48    ///
49    /// # Errors
50    ///
51    /// Returns `Error::Os` if any system call fails.
52    pub fn new() -> Result<Self> {
53        let fd = create_socket()?;
54        let connector = Self { fd };
55        connector.bind()?;
56        connector.subscribe()?;
57        Ok(connector)
58    }
59
60    /// Bind the socket to the `CN_IDX_PROC` netlink group.
61    fn bind(&self) -> Result<()> {
62        let mut sa: libc::sockaddr_nl = unsafe { std::mem::zeroed() };
63        sa.nl_family = libc::AF_NETLINK as u16;
64        // nl_pad left as zeroed
65        sa.nl_pid = 0; // let kernel pick a port ID
66        sa.nl_groups = CN_IDX_PROC;
67
68        let ret = unsafe {
69            libc::bind(
70                self.fd.as_raw_fd(),
71                &sa as *const libc::sockaddr_nl as *const libc::sockaddr,
72                std::mem::size_of::<libc::sockaddr_nl>() as u32,
73            )
74        };
75
76        if ret < 0 {
77            return Err(Error::Os(std::io::Error::last_os_error()));
78        }
79        Ok(())
80    }
81
82    /// (Re-)subscribe to process events.
83    ///
84    /// Sends a `PROC_CN_MCAST_LISTEN` message via the netlink socket.
85    /// This is safe to call multiple times (e.g. after a reconnection).
86    ///
87    /// # Example
88    ///
89    /// ```no_run
90    /// # use proc_connector::ProcConnector;
91    /// let mut conn = ProcConnector::new().unwrap();
92    /// conn.subscribe().expect("subscribe");
93    /// ```
94    pub fn subscribe(&self) -> Result<()> {
95        self.send_mcast_op(PROC_CN_MCAST_LISTEN)
96    }
97
98    /// Unsubscribe from process events.
99    ///
100    /// Sends a `PROC_CN_MCAST_IGNORE` message. Automatically called on drop.
101    ///
102    /// # Example
103    ///
104    /// ```no_run
105    /// # use proc_connector::ProcConnector;
106    /// let conn = ProcConnector::new().unwrap();
107    /// conn.unsubscribe().expect("unsubscribe");
108    /// // Re-subscribe later:
109    /// conn.subscribe().expect("re-subscribe");
110    /// ```
111    pub fn unsubscribe(&self) -> Result<()> {
112        self.send_mcast_op(PROC_CN_MCAST_IGNORE)
113    }
114
115    /// Send a `proc_cn_mcast_op` command to the kernel.
116    fn send_mcast_op(&self, op: u32) -> Result<()> {
117        let nlmsg_payload_len = SIZE_CN_MSG + std::mem::size_of::<u32>();
118        let nlmsg_len = nlmsg_length(nlmsg_payload_len);
119
120        let mut buf = vec![0u8; nlmsg_len];
121        let pid = std::process::id();
122
123        // --- nlmsghdr (16 bytes, little-endian wire format) ---
124        let hdr = &mut buf[..SIZE_NLMSGHDR];
125        hdr[0..4].copy_from_slice(&(nlmsg_len as u32).to_ne_bytes()); // nlmsg_len
126        hdr[4..6].copy_from_slice(&NLMSG_MIN_TYPE.to_ne_bytes()); // nlmsg_type (≥ 16 = application-defined)
127        hdr[6..8].copy_from_slice(&NLM_F_REQUEST.to_ne_bytes()); // nlmsg_flags (request)
128        hdr[8..12].copy_from_slice(&0u32.to_ne_bytes()); // nlmsg_seq
129        hdr[12..16].copy_from_slice(&pid.to_ne_bytes()); // nlmsg_pid
130
131        // --- cn_msg (20 bytes header + op payload) ---
132        let cn_off = nlmsg_hdrlen();
133        let cn = &mut buf[cn_off..cn_off + SIZE_CN_MSG + std::mem::size_of::<u32>()];
134        // id.idx
135        cn[0..4].copy_from_slice(&CN_IDX_PROC.to_ne_bytes());
136        // id.val
137        cn[4..8].copy_from_slice(&CN_VAL_PROC.to_ne_bytes());
138        // seq
139        cn[8..12].copy_from_slice(&0u32.to_ne_bytes());
140        // ack
141        cn[12..16].copy_from_slice(&0u32.to_ne_bytes());
142        // len (u16) = sizeof(proc_cn_mcast_op) = 4
143        cn[16..18].copy_from_slice(&(std::mem::size_of::<u32>() as u16).to_ne_bytes());
144        // flags
145        cn[18..20].copy_from_slice(&0u16.to_ne_bytes());
146        // data = proc_cn_mcast_op
147        cn[20..24].copy_from_slice(&op.to_ne_bytes());
148
149        let iov = libc::iovec {
150            iov_base: buf.as_mut_ptr() as *mut libc::c_void,
151            iov_len: nlmsg_len,
152        };
153
154        let msg_hdr = libc::msghdr {
155            msg_name: std::ptr::null_mut(),
156            msg_namelen: 0,
157            msg_iov: &iov as *const libc::iovec as *mut libc::iovec,
158            msg_iovlen: 1,
159            msg_control: std::ptr::null_mut(),
160            msg_controllen: 0,
161            msg_flags: 0,
162        };
163
164        let ret = unsafe { libc::sendmsg(self.fd.as_raw_fd(), &msg_hdr, 0) };
165        if ret < 0 {
166            return Err(Error::Os(std::io::Error::last_os_error()));
167        }
168        Ok(())
169    }
170
171    /// Receive a raw netlink message into `buf`.
172    ///
173    /// On success returns the number of bytes written to `buf`.
174    ///
175    /// This is a thin wrapper around `recv(2)`. The caller is responsible
176    /// for providing a sufficiently large buffer (a page size, 4096 bytes,
177    /// is a safe default).
178    ///
179    /// # Errors
180    ///
181    /// - `Interrupted` if `EINTR` — retry the call.
182    /// - `ConnectionClosed` if recv returns 0.
183    /// - `Os` for other system call errors.
184    pub fn recv_raw(&self, buf: &mut [u8]) -> Result<usize> {
185        let len = unsafe {
186            libc::recv(
187                self.fd.as_raw_fd(),
188                buf.as_mut_ptr() as *mut libc::c_void,
189                buf.len(),
190                0,
191            )
192        };
193
194        if len < 0 {
195            let err = std::io::Error::last_os_error();
196            return match err.raw_os_error() {
197                Some(libc::EINTR) => Err(Error::Interrupted),
198                Some(libc::EAGAIN) => Err(Error::WouldBlock), // EWOULDBLOCK == EAGAIN on Linux
199                _ => Err(Error::Os(err)),
200            };
201        }
202
203        if len == 0 {
204            return Err(Error::ConnectionClosed);
205        }
206
207        Ok(len as usize)
208    }
209
210    /// Receive a raw netlink message with a timeout.
211    ///
212    /// Returns `Ok(None)` if the timeout expires before data is available.
213    /// Otherwise behaves the same as `recv_raw`.
214    ///
215    /// Uses `poll(2)` internally and only calls `recv` when data is ready.
216    ///
217    /// # Example
218    ///
219    /// ```no_run
220    /// # use proc_connector::ProcConnector;
221    /// # use std::time::Duration;
222    /// let conn = ProcConnector::new().unwrap();
223    /// let mut buf = vec![0u8; 4096];
224    ///
225    /// match conn.recv_timeout(&mut buf, Duration::from_secs(5)) {
226    ///     Ok(Some(event)) => println!("{event}"),
227    ///     Ok(None) => eprintln!("timeout, no event in 5s"),
228    ///     Err(e) => eprintln!("error: {e}"),
229    /// }
230    /// ```
231    pub fn recv_raw_timeout(&self, buf: &mut [u8], timeout: Duration) -> Result<Option<usize>> {
232        let poll_fd = libc::pollfd {
233            fd: self.fd.as_raw_fd(),
234            events: libc::POLLIN,
235            revents: 0,
236        };
237
238        let timeout_ms = timeout.as_millis().try_into().unwrap_or(libc::c_int::MAX);
239
240        let ret = unsafe { libc::poll(&poll_fd as *const libc::pollfd as *mut _, 1, timeout_ms) };
241
242        if ret < 0 {
243            let err = std::io::Error::last_os_error();
244            return match err.raw_os_error() {
245                Some(libc::EINTR) => Err(Error::Interrupted),
246                _ => Err(Error::Os(err)),
247            };
248        }
249
250        if ret == 0 {
251            return Ok(None);
252        }
253
254        // recv_raw may return WouldBlock if the fd was set to non-blocking
255        // between poll() and recv(). Treat as timeout.
256        match self.recv_raw(buf) {
257            Ok(n) => Ok(Some(n)),
258            Err(Error::WouldBlock) => Ok(None),
259            Err(e) => Err(e),
260        }
261    }
262
263    /// Receive and parse the next process event.
264    ///
265    /// `buf` is the receive buffer provided by the caller (allocation control).
266    /// A buffer of at least 4096 bytes (one page) is recommended.
267    ///
268    /// This method handles all netlink control messages internally:
269    /// - `NLMSG_NOOP` → silently skipped, continue reading
270    /// - `NLMSG_DONE` (with no payload) → silently skipped, continue reading
271    ///   (The kernel connector protocol uses `NLMSG_DONE` with a cn_msg payload
272    ///   for data messages, which are parsed as events.)
273    /// - `NLMSG_ERROR` (non-zero) → returned as `Err(Os(...))`
274    /// - `NLMSG_OVERRUN` → returned as `Err(Overrun)`
275    /// - Valid data → parsed into `ProcEvent`
276    ///
277    /// # Errors
278    ///
279    /// See [`recv_raw`] for system-level errors.
280    /// Additionally returns `BufferTooSmall` if the buffer is too small
281    /// to hold even a single netlink header.
282    ///
283    /// # Example
284    ///
285    /// ```no_run
286    /// use proc_connector::ProcConnector;
287    ///
288    /// let conn = ProcConnector::new().unwrap();
289    /// let mut buf = [0u8; 4096];
290    /// loop {
291    ///     match conn.recv(&mut buf) {
292    ///         Ok(event) => println!("{event}"),
293    ///         Err(e) => { eprintln!("{e}"); break; }
294    ///     }
295    /// }
296    /// ```
297    pub fn recv(&self, buf: &mut [u8]) -> Result<ProcEvent> {
298        self.recv_impl(buf)
299    }
300
301    /// Receive and parse the next process event with a timeout.
302    ///
303    /// Returns `Ok(None)` if the timeout expires before an event is available.
304    ///
305    /// Unlike `recv()`, this method returns `Ok(None)` on timeout instead of
306    /// blocking indefinitely. It properly loops past netlink control messages
307    /// (NLMSG_NOOP, NLMSG_DONE, NLMSG_ERROR-ACK) just like `recv()` does.
308    ///
309    /// # Errors
310    ///
311    /// See [`recv_timeout`] for system-level errors.
312    pub fn recv_timeout(
313        &self,
314        buf: &mut [u8],
315        timeout: std::time::Duration,
316    ) -> Result<Option<ProcEvent>> {
317        if buf.len() < SIZE_NLMSGHDR {
318            return Err(Error::BufferTooSmall {
319                needed: SIZE_NLMSGHDR,
320            });
321        }
322
323        loop {
324            let n = match self.recv_raw_timeout(buf, timeout) {
325                Ok(Some(n)) => n,
326                Ok(None) => return Ok(None),
327                Err(Error::WouldBlock) => {
328                    // Non-blocking mode — should not happen with timeout
329                    // since recv_raw_timeout uses poll(). Treat as timeout.
330                    return Ok(None);
331                }
332                Err(e) => return Err(e),
333            };
334
335            // Iterate over all messages in the buffer, skipping control messages
336            if let Some(event) = crate::parse::first_event_from_buf(buf, n)? {
337                return Ok(Some(event));
338            }
339            // All messages were control messages — loop back and wait for a real event
340        }
341    }
342
343    /// Internal: block until a process event is received.
344    ///
345    /// Loops past netlink control messages (NLMSG_NOOP, NLMSG_DONE, NLMSG_ERROR-ACK)
346    /// until a real ProcEvent is parsed.
347    fn recv_impl(&self, buf: &mut [u8]) -> Result<ProcEvent> {
348        if buf.len() < SIZE_NLMSGHDR {
349            return Err(Error::BufferTooSmall {
350                needed: SIZE_NLMSGHDR,
351            });
352        }
353
354        loop {
355            let n = match self.recv_raw(buf) {
356                Ok(n) => n,
357                Err(Error::WouldBlock) => {
358                    // Non-blocking mode and no data — caller should use
359                    // poll/AsyncFd instead of blocking recv.
360                    // Return a non-recoverable error for blocking API.
361                    return Err(Error::Os(std::io::Error::new(
362                        std::io::ErrorKind::WouldBlock,
363                        "socket is non-blocking, use AsyncFd to wait for readiness",
364                    )));
365                }
366                Err(e) => return Err(e),
367            };
368
369            // Parse all messages in the buffer
370            if let Some(event) = crate::parse::first_event_from_buf(buf, n)? {
371                return Ok(event);
372            }
373
374            // If we got here, all messages were control messages.
375            // Loop back and wait for a real event.
376        }
377    }
378
379    /// Expose the raw file descriptor for integration with async runtimes
380    /// (`tokio::AsyncFd`, `mio`, etc.).
381    ///
382    /// The returned `RawFd` remains valid for the lifetime of this
383    /// `ProcConnector`. Do not close it manually.
384    ///
385    /// # Example
386    ///
387    /// ```no_run
388    /// # use proc_connector::ProcConnector;
389    /// # use std::os::fd::AsRawFd;
390    /// let conn = ProcConnector::new().unwrap();
391    /// let raw = conn.as_raw_fd();
392    /// assert!(raw >= 0);
393    ///
394    /// // Use with tokio:
395    /// // let async_fd = tokio::io::unix::AsyncFd::new(conn).unwrap();
396    /// ```
397    pub fn as_raw_fd(&self) -> RawFd {
398        self.fd.as_raw_fd()
399    }
400
401    /// Set the socket to non-blocking mode.
402    ///
403    /// After calling this, `recv_raw` will return `Error::WouldBlock`
404    /// when no data is available, instead of blocking.
405    ///
406    /// # Errors
407    ///
408    /// Returns `Error::Os` if `fcntl(2)` fails.
409    pub fn set_nonblocking(&self) -> Result<()> {
410        let flags = unsafe { libc::fcntl(self.fd.as_raw_fd(), libc::F_GETFL) };
411        if flags < 0 {
412            return Err(Error::Os(std::io::Error::last_os_error()));
413        }
414        let ret =
415            unsafe { libc::fcntl(self.fd.as_raw_fd(), libc::F_SETFL, flags | libc::O_NONBLOCK) };
416        if ret < 0 {
417            return Err(Error::Os(std::io::Error::last_os_error()));
418        }
419        Ok(())
420    }
421}
422
423impl AsRawFd for ProcConnector {
424    fn as_raw_fd(&self) -> RawFd {
425        self.fd.as_raw_fd()
426    }
427}
428
429impl AsFd for ProcConnector {
430    fn as_fd(&self) -> BorrowedFd<'_> {
431        self.fd.as_fd()
432    }
433}
434
435impl Drop for ProcConnector {
436    fn drop(&mut self) {
437        // Best-effort unsubscribe; ignore errors since we're closing anyway.
438        let _ = self.unsubscribe();
439    }
440}
441
// ---------------------------------------------------------------------------
// Helper: create the netlink socket
// ---------------------------------------------------------------------------
445
446fn create_socket() -> Result<OwnedFd> {
447    let fd = unsafe {
448        let fd = libc::socket(libc::PF_NETLINK, libc::SOCK_DGRAM, NETLINK_CONNECTOR);
449        if fd < 0 {
450            return Err(Error::Os(std::io::Error::last_os_error()));
451        }
452        OwnedFd::from_raw_fd(fd)
453    };
454
455    // Enable NETLINK_NO_ENOBUFS so the kernel doesn't silently drop
456    // our subscription message when the socket buffer is full.
457    let val: libc::c_int = 1;
458    let ret = unsafe {
459        libc::setsockopt(
460            fd.as_raw_fd(),
461            libc::SOL_NETLINK,
462            NETLINK_NO_ENOBUFS,
463            &val as *const libc::c_int as *const libc::c_void,
464            std::mem::size_of::<libc::c_int>() as u32,
465        )
466    };
467    if ret < 0 {
468        // Non-fatal: proceed even if this option fails.
469        let _ = std::io::Error::last_os_error();
470    }
471
472    Ok(fd)
473}