Skip to main content

jsonrpc_fdpass/
transport.rs

1use crate::error::{Error, Result};
2use crate::message::{JsonRpcMessage, JsonRpcNotification, MessageWithFds, get_fd_count};
3use rustix::fd::AsFd;
4use rustix::net::{
5    RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags, SendAncillaryBuffer,
6    SendAncillaryMessage, SendFlags,
7};
8use serde::Serialize;
9use std::collections::VecDeque;
10use std::io::{self, IoSlice, IoSliceMut};
11use std::mem::MaybeUninit;
12use std::num::NonZeroUsize;
13use std::os::unix::io::OwnedFd;
14use std::sync::Arc;
15use tokio::io::Interest;
16use tokio::net::UnixStream as TokioUnixStream;
17use tracing::{debug, trace};
18
/// Default maximum number of file descriptors per sendmsg() call.
///
/// Platform limits for SCM_RIGHTS vary (e.g., ~253 on Linux, ~512 on macOS).
/// We start with an optimistic value; if sendmsg() fails with EINVAL, the
/// batch size is automatically reduced and the send is retried.
pub const DEFAULT_MAX_FDS_PER_SENDMSG: NonZeroUsize = NonZeroUsize::new(500).unwrap();

/// Maximum FDs to expect in a single recvmsg() call.
///
/// Must be at least as large as the largest platform limit (~512 on macOS).
/// The sender also sizes its SCM_RIGHTS control buffer with this value, so it
/// additionally bounds how many FDs can be attached to one sendmsg() call.
const MAX_FDS_PER_RECVMSG: usize = 512;

// The sender's ancillary-data buffer has room for at most MAX_FDS_PER_RECVMSG
// descriptors; a larger default batch would make sends carrying that many FDs
// fail while building the control message, before the kernel is even asked.
const _: () = assert!(DEFAULT_MAX_FDS_PER_SENDMSG.get() <= MAX_FDS_PER_RECVMSG);

/// Size in bytes of the stack buffer used for each recvmsg() call.
const READ_BUFFER_SIZE: usize = 4096;
32
33/// Transport layer for Unix socket communication with file descriptor passing.
34pub struct UnixSocketTransport {
35    stream: TokioUnixStream,
36}
37
38impl UnixSocketTransport {
39    /// Create a new transport from an existing Unix stream.
40    pub fn new(stream: TokioUnixStream) -> Self {
41        Self { stream }
42    }
43
44    /// Split the transport into separate sender and receiver halves.
45    pub fn split(self) -> (Sender, Receiver) {
46        let stream = Arc::new(self.stream);
47
48        (
49            Sender {
50                stream: Arc::clone(&stream),
51                pretty: false,
52                max_fds_per_sendmsg: DEFAULT_MAX_FDS_PER_SENDMSG,
53            },
54            Receiver {
55                stream,
56                buffer: Vec::new(),
57                fd_queue: VecDeque::new(),
58                pending_message: None,
59            },
60        )
61    }
62}
63
64/// Sender half of a Unix socket transport for sending JSON-RPC messages.
65pub struct Sender {
66    stream: Arc<TokioUnixStream>,
67    pretty: bool,
68    /// Maximum FDs to send per sendmsg() call. Configurable for testing.
69    max_fds_per_sendmsg: NonZeroUsize,
70}
71
72impl Sender {
73    /// Enable or disable pretty-printed JSON output.
74    ///
75    /// When enabled, messages are serialized with indentation and newlines.
76    /// This is useful for debugging or when interoperating with tools that
77    /// expect human-readable JSON.
78    pub fn set_pretty(&mut self, pretty: bool) {
79        self.pretty = pretty;
80    }
81
82    /// Set the maximum number of file descriptors to send per sendmsg() call.
83    ///
84    /// This is primarily useful for testing FD batching behavior. The default
85    /// value ([`DEFAULT_MAX_FDS_PER_SENDMSG`]) is optimistic and may exceed
86    /// some platform limits; if sendmsg() returns `EINVAL`, the batch size is
87    /// automatically reduced and the send is retried.
88    pub fn set_max_fds_per_sendmsg(&mut self, max_fds: NonZeroUsize) {
89        self.max_fds_per_sendmsg = max_fds;
90    }
91
92    /// Send a notification without file descriptors.
93    ///
94    /// This is a convenience method that serializes the params and constructs
95    /// the notification message automatically.
96    pub async fn notify<P: Serialize>(&mut self, method: &str, params: P) -> Result<()> {
97        self.notify_with_fds(method, params, Vec::new()).await
98    }
99
100    /// Send a notification with file descriptors.
101    ///
102    /// This is a convenience method that serializes the params and constructs
103    /// the notification message automatically.
104    pub async fn notify_with_fds<P: Serialize>(
105        &mut self,
106        method: &str,
107        params: P,
108        fds: Vec<OwnedFd>,
109    ) -> Result<()> {
110        let params_value = serde_json::to_value(params)?;
111        let params_opt = if params_value.is_null() {
112            None
113        } else {
114            Some(params_value)
115        };
116        let notification = JsonRpcNotification::new(method.to_string(), params_opt);
117        let message = JsonRpcMessage::Notification(notification);
118        let message_with_fds = MessageWithFds::new(message, fds);
119        self.send(message_with_fds).await
120    }
121
122    /// Send a JSON-RPC message with optional file descriptors.
123    pub async fn send(&mut self, message_with_fds: MessageWithFds) -> Result<()> {
124        let serialized = if self.pretty {
125            message_with_fds.serialize_pretty()?
126        } else {
127            message_with_fds.serialize()?
128        };
129        let data = serialized.into_bytes();
130
131        trace!(
132            "Sending message: {} with {} FDs",
133            String::from_utf8_lossy(&data).trim(),
134            message_with_fds.file_descriptors.len()
135        );
136
137        let fds = message_with_fds.file_descriptors;
138
139        // Track how many bytes and FDs we've sent so far
140        let mut bytes_sent = 0usize;
141        let mut fds_sent = 0usize;
142
143        // Current max FDs per batch - may be reduced if we hit EINVAL
144        let mut current_max_fds = self.max_fds_per_sendmsg.get();
145
146        // Send data with FDs in batches. Each sendmsg can only handle a limited number of FDs.
147        // We send FDs with the data chunks, and any remaining FDs after all data is sent.
148        while bytes_sent < data.len() || fds_sent < fds.len() {
149            let remaining_data = &data[bytes_sent..];
150            let remaining_fds = &fds[fds_sent..];
151
152            // Determine how many FDs to send in this batch (up to current_max_fds)
153            let fds_batch = remaining_fds
154                .get(..current_max_fds)
155                .unwrap_or(remaining_fds);
156
157            let result = self
158                .stream
159                .async_io(Interest::WRITABLE, || {
160                    let sockfd = self.stream.as_fd();
161
162                    if !fds_batch.is_empty() {
163                        // Send with FDs using sendmsg with ancillary data
164                        let borrowed_fds: Vec<_> = fds_batch.iter().map(|fd| fd.as_fd()).collect();
165
166                        let mut buffer: [MaybeUninit<u8>;
167                            rustix::cmsg_space!(ScmRights(MAX_FDS_PER_RECVMSG))] =
168                            [MaybeUninit::uninit();
169                                rustix::cmsg_space!(ScmRights(MAX_FDS_PER_RECVMSG))];
170                        let mut control = SendAncillaryBuffer::new(&mut buffer);
171
172                        if !control.push(SendAncillaryMessage::ScmRights(&borrowed_fds)) {
173                            return Err(io::Error::other(
174                                "Failed to add file descriptors to control message",
175                            ));
176                        }
177
178                        // If we have data to send, include it; otherwise send a minimal byte
179                        // (some systems require non-empty iov for ancillary data)
180                        let iov = if !remaining_data.is_empty() {
181                            [IoSlice::new(remaining_data)]
182                        } else {
183                            // Send a space byte that will be ignored by the receiver's JSON parser.
184                            // RFC 8259 defines space (0x20) as insignificant whitespace, and
185                            // serde_json's StreamDeserializer skips whitespace between values.
186                            [IoSlice::new(b" ")]
187                        };
188
189                        rustix::net::sendmsg(sockfd, &iov, &mut control, SendFlags::empty())
190                            .map_err(|e| to_io_error(e, "sendmsg"))
191                    } else if !remaining_data.is_empty() {
192                        // No FDs left, just send remaining data
193                        rustix::net::send(sockfd, remaining_data, SendFlags::empty())
194                            .map_err(|e| to_io_error(e, "send"))
195                    } else {
196                        // Nothing left to send
197                        Ok(0)
198                    }
199                })
200                .await;
201
202            match result {
203                Ok(sent) => {
204                    // Update bytes sent (but only count actual data bytes, not padding)
205                    if !remaining_data.is_empty() {
206                        bytes_sent += sent;
207                    }
208
209                    // Update FDs sent
210                    if !fds_batch.is_empty() {
211                        fds_sent += fds_batch.len();
212                        trace!(
213                            "Sent {} FDs (total: {}/{}) with {} bytes",
214                            fds_batch.len(),
215                            fds_sent,
216                            fds.len(),
217                            sent
218                        );
219                    }
220
221                    trace!(
222                        "Progress: {}/{} bytes, {}/{} FDs",
223                        bytes_sent,
224                        data.len(),
225                        fds_sent,
226                        fds.len()
227                    );
228                }
229                Err(e) if e.kind() == io::ErrorKind::InvalidInput && fds_batch.len() > 1 => {
230                    // EINVAL with multiple FDs likely means we exceeded the kernel's
231                    // SCM_MAX_FD limit. Reduce batch size and retry.
232                    let new_max = fds_batch.len() / 2;
233                    debug!(
234                        "sendmsg returned EINVAL with {} FDs, reducing batch size to {}",
235                        fds_batch.len(),
236                        new_max
237                    );
238                    current_max_fds = new_max;
239                    // Don't update bytes_sent or fds_sent - we'll retry this batch
240                    continue;
241                }
242                Err(e) => return Err(Error::Io(e)),
243            }
244        }
245
246        // If we discovered a lower limit, remember it for future sends
247        if current_max_fds < self.max_fds_per_sendmsg.get() {
248            debug!(
249                "Learned kernel FD limit: reducing max_fds_per_sendmsg from {} to {}",
250                self.max_fds_per_sendmsg, current_max_fds
251            );
252            // current_max_fds is at least 1 (we only reduce when fds_this_batch > 1)
253            self.max_fds_per_sendmsg =
254                NonZeroUsize::new(current_max_fds).expect("current_max_fds should be >= 1");
255        }
256
257        Ok(())
258    }
259}
260
261/// Receiver half of a Unix socket transport for receiving JSON-RPC messages.
262pub struct Receiver {
263    stream: Arc<TokioUnixStream>,
264    buffer: Vec<u8>,
265    fd_queue: VecDeque<OwnedFd>,
266    /// A fully parsed JSON message waiting for its FDs to arrive.
267    pending_message: Option<(serde_json::Value, usize)>,
268}
269
270impl Receiver {
271    /// Receive a message, returning an error on connection close.
272    ///
273    /// See also [`receive_opt`](Self::receive_opt) which returns `Ok(None)`
274    /// on connection close instead of an error.
275    pub async fn receive(&mut self) -> Result<MessageWithFds> {
276        loop {
277            if let Some(message) = self.try_parse_message()? {
278                return Ok(message);
279            }
280
281            if let Err(e) = self.read_more_data().await {
282                if matches!(e, Error::ConnectionClosed)
283                    && let Some((_, fd_count)) = self.pending_message.take()
284                {
285                    // Connection closed while waiting for FDs — per spec
286                    // Section 5, Step 4 this is a Mismatched Count error.
287                    return Err(Error::MismatchedCount {
288                        expected: fd_count,
289                        found: self.fd_queue.len(),
290                    });
291                }
292                return Err(e);
293            }
294        }
295    }
296
297    /// Receive a message, returning `Ok(None)` on connection close.
298    ///
299    /// This is a convenience method that converts `Error::ConnectionClosed`
300    /// to `Ok(None)`, which is useful for receiver loops:
301    ///
302    /// ```ignore
303    /// while let Some(msg) = receiver.receive_opt().await? {
304    ///     // handle message
305    /// }
306    /// ```
307    ///
308    /// See also [`receive`](Self::receive) which returns an error on
309    /// connection close.
310    pub async fn receive_opt(&mut self) -> Result<Option<MessageWithFds>> {
311        match self.receive().await {
312            Ok(msg) => Ok(Some(msg)),
313            Err(Error::ConnectionClosed) => Ok(None),
314            Err(e) => Err(e),
315        }
316    }
317
318    /// Build a `MessageWithFds` by draining `fd_count` FDs from the queue.
319    fn build_message(
320        fd_queue: &mut VecDeque<OwnedFd>,
321        value: serde_json::Value,
322        fd_count: usize,
323    ) -> Result<MessageWithFds> {
324        let fds: Vec<OwnedFd> = fd_queue.drain(..fd_count).collect();
325        let message = JsonRpcMessage::from_json_value(value)?;
326        Ok(MessageWithFds::new(message, fds))
327    }
328
329    fn try_parse_message(&mut self) -> Result<Option<MessageWithFds>> {
330        // Check if we have a pending message waiting for FDs.
331        // While a message is pending, all subsequent message parsing is
332        // blocked — even messages needing 0 FDs.  This preserves FIFO
333        // ordering on the Unix socket: FDs queued after the pending
334        // message's FDs belong to later messages and must not be
335        // consumed early.
336        if let Some((value, fd_count)) = self
337            .pending_message
338            .take_if(|(_, c)| self.fd_queue.len() >= *c)
339        {
340            return Ok(Some(Self::build_message(
341                &mut self.fd_queue,
342                value,
343                fd_count,
344            )?));
345        } else if let Some((_, fd_count)) = &self.pending_message {
346            // Not enough FDs yet.  Per the spec (Section 5, Step 4),
347            // if the buffer contains any non-whitespace byte the sender
348            // has started the next message before delivering all FDs for
349            // the current one — that is a fatal protocol violation.
350            if self.buffer.iter().any(|&b| !b.is_ascii_whitespace()) {
351                return Err(Error::MismatchedCount {
352                    expected: *fd_count,
353                    found: self.fd_queue.len(),
354                });
355            }
356            return Ok(None);
357        }
358
359        if self.buffer.is_empty() {
360            return Ok(None);
361        }
362
363        // Use streaming JSON parser to find message boundaries
364        let mut stream =
365            serde_json::Deserializer::from_slice(&self.buffer).into_iter::<serde_json::Value>();
366
367        match stream.next() {
368            Some(Ok(value)) => {
369                // Successfully parsed a complete JSON value
370                let bytes_consumed = stream.byte_offset();
371
372                trace!("Parsed message ({} bytes): {:?}", bytes_consumed, value);
373
374                // Drain the consumed bytes from the buffer
375                self.buffer.drain(..bytes_consumed);
376
377                // Read the fds count from the message and extract FDs
378                let fd_count = get_fd_count(&value);
379
380                if fd_count > self.fd_queue.len() {
381                    // FDs may arrive across multiple recvmsg() calls when the
382                    // sender batches them.  Buffer the parsed message and let
383                    // the receive() loop read more data.
384                    //
385                    // Per the spec (Section 5, Step 4), if the buffer already
386                    // contains non-whitespace bytes the sender has started the
387                    // next message before delivering all FDs — a fatal error.
388                    if self.buffer.iter().any(|&b| !b.is_ascii_whitespace()) {
389                        return Err(Error::MismatchedCount {
390                            expected: fd_count,
391                            found: self.fd_queue.len(),
392                        });
393                    }
394                    trace!(
395                        "Message expects {} FDs but only {} available, waiting for more",
396                        fd_count,
397                        self.fd_queue.len()
398                    );
399                    self.pending_message = Some((value, fd_count));
400                    return Ok(None);
401                }
402
403                Ok(Some(Self::build_message(
404                    &mut self.fd_queue,
405                    value,
406                    fd_count,
407                )?))
408            }
409            Some(Err(e)) if e.is_eof() => {
410                // Incomplete JSON - need more data
411                Ok(None)
412            }
413            Some(Err(e)) => {
414                // Actual parse error
415                Err(Error::Json(e))
416            }
417            None => {
418                // No more values (shouldn't happen with non-empty buffer, but handle it)
419                Ok(None)
420            }
421        }
422    }
423
424    async fn read_more_data(&mut self) -> Result<()> {
425        let mut data_buffer = [0u8; READ_BUFFER_SIZE];
426        let mut received_fds: Vec<OwnedFd> = Vec::new();
427
428        let bytes_read = self
429            .stream
430            .async_io(Interest::READABLE, || {
431                let sockfd = self.stream.as_fd();
432
433                let mut iov = [IoSliceMut::new(&mut data_buffer)];
434                let mut cmsg_space: [MaybeUninit<u8>;
435                    rustix::cmsg_space!(ScmRights(MAX_FDS_PER_RECVMSG))] =
436                    [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(MAX_FDS_PER_RECVMSG))];
437                let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
438
439                let result = rustix::net::recvmsg(
440                    sockfd,
441                    &mut iov,
442                    &mut cmsg_buffer,
443                    RecvFlags::CMSG_CLOEXEC,
444                )
445                .map_err(|e| to_io_error(e, "recvmsg"))?;
446
447                // Extract file descriptors from control messages
448                for msg in cmsg_buffer.drain() {
449                    if let RecvAncillaryMessage::ScmRights(fds) = msg {
450                        received_fds.extend(fds);
451                    }
452                }
453
454                Ok(result.bytes)
455            })
456            .await
457            .map_err(Error::Io)?;
458
459        if bytes_read == 0 {
460            return Err(Error::ConnectionClosed);
461        }
462
463        self.buffer.extend_from_slice(&data_buffer[..bytes_read]);
464        self.fd_queue.extend(received_fds);
465
466        debug!(
467            "Read {} bytes, {} FDs in queue",
468            bytes_read,
469            self.fd_queue.len()
470        );
471        Ok(())
472    }
473}
474
475/// Convert a rustix error to an io::Error, preserving EAGAIN/EWOULDBLOCK for async_io
476fn to_io_error(e: rustix::io::Errno, operation: &str) -> io::Error {
477    // rustix::io::Errno can be converted to io::Error, which preserves the error kind
478    let io_err: io::Error = e.into();
479    if io_err.kind() == io::ErrorKind::WouldBlock {
480        io_err
481    } else {
482        io::Error::new(io_err.kind(), format!("{} failed: {}", operation, io_err))
483    }
484}