proc_connector/socket.rs
1//! Core `ProcConnector` type — safe wrapper around a Linux Netlink Proc Connector socket.
2//!
3//! # Lifecycle
4//!
5//! 1. `ProcConnector::new()` — creates the netlink socket, binds to `CN_IDX_PROC`, subscribes.
6//! 2. `recv()` / `recv_timeout()` — receive and parse process events.
7//! 3. `unsubscribe()` / `subscribe()` — toggle subscription (useful after reconnect).
8//! 4. Drop — automatically unsubscribes and closes the socket.
9
10use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
11use std::time::Duration;
12
13use crate::consts::*;
14use crate::error::{Error, Result};
15
/// A safe handle to a Linux Netlink Proc Connector socket.
///
/// The socket is created, bound, and subscribed in `new()`. On drop, the
/// subscription is cancelled and the file descriptor is closed automatically.
///
/// The only state held is the owned file descriptor, so the `Debug`
/// representation simply shows that descriptor.
///
/// # Examples
///
/// ```no_run
/// use proc_connector::ProcConnector;
///
/// let conn = ProcConnector::new().unwrap();
/// let mut buf = vec![0u8; 4096];
///
/// loop {
///     match conn.recv(&mut buf) {
///         Ok(event) => println!("got event: {event:?}"),
///         Err(e) => eprintln!("error: {e}"),
///     }
/// }
/// ```
#[derive(Debug)]
pub struct ProcConnector {
    /// Owned netlink socket; closed when the `ProcConnector` is dropped.
    fd: OwnedFd,
}
39
40impl ProcConnector {
41 /// Create a new `ProcConnector`.
42 ///
43 /// This is a convenience constructor that:
44 /// 1. Creates a `PF_NETLINK` / `SOCK_DGRAM` socket of family `NETLINK_CONNECTOR`.
45 /// 2. Binds to the `CN_IDX_PROC` multicast group.
46 /// 3. Sends a `PROC_CN_MCAST_LISTEN` subscription message.
47 ///
48 /// # Errors
49 ///
50 /// Returns `Error::Os` if any system call fails.
51 pub fn new() -> Result<Self> {
52 let fd = create_socket()?;
53 let connector = Self { fd };
54 connector.bind()?;
55 connector.subscribe()?;
56 Ok(connector)
57 }
58
59 /// Bind the socket to the `CN_IDX_PROC` netlink group.
60 fn bind(&self) -> Result<()> {
61 let mut sa: libc::sockaddr_nl = unsafe { std::mem::zeroed() };
62 sa.nl_family = libc::AF_NETLINK as u16;
63 // nl_pad left as zeroed
64 sa.nl_pid = 0; // let kernel pick a port ID
65 sa.nl_groups = CN_IDX_PROC;
66
67 let ret = unsafe {
68 libc::bind(
69 self.fd.as_raw_fd(),
70 &sa as *const libc::sockaddr_nl as *const libc::sockaddr,
71 std::mem::size_of::<libc::sockaddr_nl>() as u32,
72 )
73 };
74
75 if ret < 0 {
76 return Err(Error::Os(std::io::Error::last_os_error()));
77 }
78 Ok(())
79 }
80
81 /// (Re-)subscribe to process events.
82 ///
83 /// Sends a `PROC_CN_MCAST_LISTEN` message via the netlink socket.
84 /// This is safe to call multiple times (e.g. after a reconnection).
85 ///
86 /// # Example
87 ///
88 /// ```no_run
89 /// # use proc_connector::ProcConnector;
90 /// let mut conn = ProcConnector::new().unwrap();
91 /// conn.subscribe().expect("subscribe");
92 /// ```
93 pub fn subscribe(&self) -> Result<()> {
94 self.send_mcast_op(PROC_CN_MCAST_LISTEN)
95 }
96
97 /// Unsubscribe from process events.
98 ///
99 /// Sends a `PROC_CN_MCAST_IGNORE` message. Automatically called on drop.
100 ///
101 /// # Example
102 ///
103 /// ```no_run
104 /// # use proc_connector::ProcConnector;
105 /// let conn = ProcConnector::new().unwrap();
106 /// conn.unsubscribe().expect("unsubscribe");
107 /// // Re-subscribe later:
108 /// conn.subscribe().expect("re-subscribe");
109 /// ```
110 pub fn unsubscribe(&self) -> Result<()> {
111 self.send_mcast_op(PROC_CN_MCAST_IGNORE)
112 }
113
114 /// Send a `proc_cn_mcast_op` command to the kernel.
115 fn send_mcast_op(&self, op: u32) -> Result<()> {
116 let nlmsg_payload_len = SIZE_CN_MSG + std::mem::size_of::<u32>();
117 let nlmsg_len = nlmsg_length(nlmsg_payload_len);
118
119 let mut buf = vec![0u8; nlmsg_len];
120 let pid = std::process::id();
121
122 // --- nlmsghdr (16 bytes, little-endian wire format) ---
123 let hdr = &mut buf[..SIZE_NLMSGHDR];
124 hdr[0..4].copy_from_slice(&(nlmsg_len as u32).to_ne_bytes()); // nlmsg_len
125 hdr[4..6].copy_from_slice(&NLMSG_MIN_TYPE.to_ne_bytes()); // nlmsg_type (≥ 16 = application-defined)
126 hdr[6..8].copy_from_slice(&NLM_F_REQUEST.to_ne_bytes()); // nlmsg_flags (request)
127 hdr[8..12].copy_from_slice(&0u32.to_ne_bytes()); // nlmsg_seq
128 hdr[12..16].copy_from_slice(&pid.to_ne_bytes()); // nlmsg_pid
129
130 // --- cn_msg (20 bytes header + op payload) ---
131 let cn_off = nlmsg_hdrlen();
132 let cn = &mut buf[cn_off..cn_off + SIZE_CN_MSG + std::mem::size_of::<u32>()];
133 // id.idx
134 cn[0..4].copy_from_slice(&CN_IDX_PROC.to_ne_bytes());
135 // id.val
136 cn[4..8].copy_from_slice(&CN_VAL_PROC.to_ne_bytes());
137 // seq
138 cn[8..12].copy_from_slice(&0u32.to_ne_bytes());
139 // ack
140 cn[12..16].copy_from_slice(&0u32.to_ne_bytes());
141 // len (u16) = sizeof(proc_cn_mcast_op) = 4
142 cn[16..18]
143 .copy_from_slice(&(std::mem::size_of::<u32>() as u16).to_ne_bytes());
144 // flags
145 cn[18..20].copy_from_slice(&0u16.to_ne_bytes());
146 // data = proc_cn_mcast_op
147 cn[20..24].copy_from_slice(&op.to_ne_bytes());
148
149 let iov = libc::iovec {
150 iov_base: buf.as_mut_ptr() as *mut libc::c_void,
151 iov_len: nlmsg_len,
152 };
153
154 let msg_hdr = libc::msghdr {
155 msg_name: std::ptr::null_mut(),
156 msg_namelen: 0,
157 msg_iov: &iov as *const libc::iovec as *mut libc::iovec,
158 msg_iovlen: 1,
159 msg_control: std::ptr::null_mut(),
160 msg_controllen: 0,
161 msg_flags: 0,
162 };
163
164 let ret = unsafe { libc::sendmsg(self.fd.as_raw_fd(), &msg_hdr, 0) };
165 if ret < 0 {
166 return Err(Error::Os(std::io::Error::last_os_error()));
167 }
168 Ok(())
169 }
170
171 /// Receive a raw netlink message into `buf`.
172 ///
173 /// On success returns the number of bytes written to `buf`.
174 ///
175 /// This is a thin wrapper around `recv(2)`. The caller is responsible
176 /// for providing a sufficiently large buffer (a page size, 4096 bytes,
177 /// is a safe default).
178 ///
179 /// # Errors
180 ///
181 /// - `Interrupted` if `EINTR` — retry the call.
182 /// - `ConnectionClosed` if recv returns 0.
183 /// - `Os` for other system call errors.
184 pub fn recv_raw(&self, buf: &mut [u8]) -> Result<usize> {
185 let len = unsafe {
186 libc::recv(
187 self.fd.as_raw_fd(),
188 buf.as_mut_ptr() as *mut libc::c_void,
189 buf.len(),
190 0,
191 )
192 };
193
194 if len < 0 {
195 let err = std::io::Error::last_os_error();
196 return match err.raw_os_error() {
197 Some(libc::EINTR) => Err(Error::Interrupted),
198 Some(libc::EAGAIN) => Err(Error::WouldBlock), // EWOULDBLOCK == EAGAIN on Linux
199 _ => Err(Error::Os(err)),
200 };
201 }
202
203 if len == 0 {
204 return Err(Error::ConnectionClosed);
205 }
206
207 Ok(len as usize)
208 }
209
210 /// Receive a raw netlink message with a timeout.
211 ///
212 /// Returns `Ok(None)` if the timeout expires before data is available.
213 /// Otherwise behaves the same as `recv_raw`.
214 ///
215 /// Uses `poll(2)` internally and only calls `recv` when data is ready.
216 ///
217 /// # Example
218 ///
219 /// ```no_run
220 /// # use proc_connector::ProcConnector;
221 /// # use std::time::Duration;
222 /// let conn = ProcConnector::new().unwrap();
223 /// let mut buf = vec![0u8; 4096];
224 ///
225 /// match conn.recv_timeout(&mut buf, Duration::from_secs(5)) {
226 /// Ok(Some(event)) => println!("{event}"),
227 /// Ok(None) => eprintln!("timeout, no event in 5s"),
228 /// Err(e) => eprintln!("error: {e}"),
229 /// }
230 /// ```
231 pub fn recv_raw_timeout(&self, buf: &mut [u8], timeout: Duration) -> Result<Option<usize>> {
232 let poll_fd = libc::pollfd {
233 fd: self.fd.as_raw_fd(),
234 events: libc::POLLIN,
235 revents: 0,
236 };
237
238 let timeout_ms = timeout
239 .as_millis()
240 .try_into()
241 .unwrap_or(libc::c_int::MAX);
242
243 let ret = unsafe { libc::poll(&poll_fd as *const libc::pollfd as *mut _, 1, timeout_ms) };
244
245 if ret < 0 {
246 let err = std::io::Error::last_os_error();
247 return match err.raw_os_error() {
248 Some(libc::EINTR) => Err(Error::Interrupted),
249 _ => Err(Error::Os(err)),
250 };
251 }
252
253 if ret == 0 {
254 return Ok(None);
255 }
256
257 // recv_raw may return WouldBlock if the fd was set to non-blocking
258 // between poll() and recv(). Treat as timeout.
259 match self.recv_raw(buf) {
260 Ok(n) => Ok(Some(n)),
261 Err(Error::WouldBlock) => Ok(None),
262 Err(e) => Err(e),
263 }
264 }
265
266 /// Expose the raw file descriptor for integration with async runtimes
267 /// (`tokio::AsyncFd`, `mio`, etc.).
268 ///
269 /// The returned `RawFd` remains valid for the lifetime of this
270 /// `ProcConnector`. Do not close it manually.
271 ///
272 /// # Example
273 ///
274 /// ```no_run
275 /// # use proc_connector::ProcConnector;
276 /// # use std::os::fd::AsRawFd;
277 /// let conn = ProcConnector::new().unwrap();
278 /// let raw = conn.as_raw_fd();
279 /// assert!(raw >= 0);
280 ///
281 /// // Use with tokio:
282 /// // let async_fd = tokio::io::unix::AsyncFd::new(conn).unwrap();
283 /// ```
284 pub fn as_raw_fd(&self) -> RawFd {
285 self.fd.as_raw_fd()
286 }
287
288 /// Set the socket to non-blocking mode.
289 ///
290 /// After calling this, `recv_raw` will return `Error::WouldBlock`
291 /// when no data is available, instead of blocking.
292 ///
293 /// # Errors
294 ///
295 /// Returns `Error::Os` if `fcntl(2)` fails.
296 pub fn set_nonblocking(&self) -> Result<()> {
297 let flags = unsafe { libc::fcntl(self.fd.as_raw_fd(), libc::F_GETFL) };
298 if flags < 0 {
299 return Err(Error::Os(std::io::Error::last_os_error()));
300 }
301 let ret = unsafe { libc::fcntl(self.fd.as_raw_fd(), libc::F_SETFL, flags | libc::O_NONBLOCK) };
302 if ret < 0 {
303 return Err(Error::Os(std::io::Error::last_os_error()));
304 }
305 Ok(())
306 }
307}
308
309impl AsRawFd for ProcConnector {
310 fn as_raw_fd(&self) -> RawFd {
311 self.fd.as_raw_fd()
312 }
313}
314
315impl AsFd for ProcConnector {
316 fn as_fd(&self) -> BorrowedFd<'_> {
317 self.fd.as_fd()
318 }
319}
320
321impl Drop for ProcConnector {
322 fn drop(&mut self) {
323 // Best-effort unsubscribe; ignore errors since we're closing anyway.
324 let _ = self.unsubscribe();
325 }
326}
327
328// ---------------------------------------------------------------------------
329// Helper: create the netlink socket
330// ---------------------------------------------------------------------------
331
332fn create_socket() -> Result<OwnedFd> {
333 let fd = unsafe {
334 let fd = libc::socket(libc::PF_NETLINK, libc::SOCK_DGRAM, NETLINK_CONNECTOR);
335 if fd < 0 {
336 return Err(Error::Os(std::io::Error::last_os_error()));
337 }
338 OwnedFd::from_raw_fd(fd)
339 };
340
341 // Enable NETLINK_NO_ENOBUFS so the kernel doesn't silently drop
342 // our subscription message when the socket buffer is full.
343 let val: libc::c_int = 1;
344 let ret = unsafe {
345 libc::setsockopt(
346 fd.as_raw_fd(),
347 libc::SOL_NETLINK,
348 NETLINK_NO_ENOBUFS,
349 &val as *const libc::c_int as *const libc::c_void,
350 std::mem::size_of::<libc::c_int>() as u32,
351 )
352 };
353 if ret < 0 {
354 // Non-fatal: proceed even if this option fails.
355 let _ = std::io::Error::last_os_error();
356 }
357
358 Ok(fd)
359}