jsonrpc_fdpass/transport.rs
1use crate::error::{Error, Result};
2use crate::message::{JsonRpcMessage, JsonRpcNotification, MessageWithFds, get_fd_count};
3use rustix::fd::AsFd;
4use rustix::net::{
5 RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags, SendAncillaryBuffer,
6 SendAncillaryMessage, SendFlags,
7};
8use serde::Serialize;
9use std::collections::VecDeque;
10use std::io::{self, IoSlice, IoSliceMut};
11use std::mem::MaybeUninit;
12use std::num::NonZeroUsize;
13use std::os::unix::io::OwnedFd;
14use std::sync::Arc;
15use tokio::io::Interest;
16use tokio::net::UnixStream as TokioUnixStream;
17use tracing::{debug, trace};
18
/// Default maximum number of file descriptors per sendmsg() call.
///
/// Platform limits for SCM_RIGHTS vary (e.g., ~253 on Linux, ~512 on macOS).
/// We start with an optimistic value; if sendmsg() fails with EINVAL, the
/// batch size is automatically reduced and the send is retried.
pub const DEFAULT_MAX_FDS_PER_SENDMSG: NonZeroUsize = NonZeroUsize::new(500).unwrap();

/// Maximum FDs to expect in a single recvmsg() call.
///
/// Must be at least as large as the largest platform limit (~512 on macOS).
/// The sender also sizes its control-message buffer with this value, so it
/// is the upper bound on FDs attached to any single sendmsg() as well.
const MAX_FDS_PER_RECVMSG: usize = 512;

// The sender's ancillary buffer holds at most MAX_FDS_PER_RECVMSG descriptors;
// a larger default batch could never be pushed into it.
const _: () = assert!(
    DEFAULT_MAX_FDS_PER_SENDMSG.get() <= MAX_FDS_PER_RECVMSG,
    "DEFAULT_MAX_FDS_PER_SENDMSG must fit in the MAX_FDS_PER_RECVMSG control buffer"
);

/// Read buffer size for incoming data.
const READ_BUFFER_SIZE: usize = 4096;
32
33/// Transport layer for Unix socket communication with file descriptor passing.
34pub struct UnixSocketTransport {
35 stream: TokioUnixStream,
36}
37
38impl UnixSocketTransport {
39 /// Create a new transport from an existing Unix stream.
40 pub fn new(stream: TokioUnixStream) -> Self {
41 Self { stream }
42 }
43
44 /// Split the transport into separate sender and receiver halves.
45 pub fn split(self) -> (Sender, Receiver) {
46 let stream = Arc::new(self.stream);
47
48 (
49 Sender {
50 stream: Arc::clone(&stream),
51 pretty: false,
52 max_fds_per_sendmsg: DEFAULT_MAX_FDS_PER_SENDMSG,
53 },
54 Receiver {
55 stream,
56 buffer: Vec::new(),
57 fd_queue: VecDeque::new(),
58 pending_message: None,
59 },
60 )
61 }
62}
63
64/// Sender half of a Unix socket transport for sending JSON-RPC messages.
65pub struct Sender {
66 stream: Arc<TokioUnixStream>,
67 pretty: bool,
68 /// Maximum FDs to send per sendmsg() call. Configurable for testing.
69 max_fds_per_sendmsg: NonZeroUsize,
70}
71
72impl Sender {
73 /// Enable or disable pretty-printed JSON output.
74 ///
75 /// When enabled, messages are serialized with indentation and newlines.
76 /// This is useful for debugging or when interoperating with tools that
77 /// expect human-readable JSON.
78 pub fn set_pretty(&mut self, pretty: bool) {
79 self.pretty = pretty;
80 }
81
82 /// Set the maximum number of file descriptors to send per sendmsg() call.
83 ///
84 /// This is primarily useful for testing FD batching behavior. The default
85 /// value ([`DEFAULT_MAX_FDS_PER_SENDMSG`]) is optimistic and may exceed
86 /// some platform limits; if sendmsg() returns `EINVAL`, the batch size is
87 /// automatically reduced and the send is retried.
88 pub fn set_max_fds_per_sendmsg(&mut self, max_fds: NonZeroUsize) {
89 self.max_fds_per_sendmsg = max_fds;
90 }
91
92 /// Send a notification without file descriptors.
93 ///
94 /// This is a convenience method that serializes the params and constructs
95 /// the notification message automatically.
96 pub async fn notify<P: Serialize>(&mut self, method: &str, params: P) -> Result<()> {
97 self.notify_with_fds(method, params, Vec::new()).await
98 }
99
100 /// Send a notification with file descriptors.
101 ///
102 /// This is a convenience method that serializes the params and constructs
103 /// the notification message automatically.
104 pub async fn notify_with_fds<P: Serialize>(
105 &mut self,
106 method: &str,
107 params: P,
108 fds: Vec<OwnedFd>,
109 ) -> Result<()> {
110 let params_value = serde_json::to_value(params)?;
111 let params_opt = if params_value.is_null() {
112 None
113 } else {
114 Some(params_value)
115 };
116 let notification = JsonRpcNotification::new(method.to_string(), params_opt);
117 let message = JsonRpcMessage::Notification(notification);
118 let message_with_fds = MessageWithFds::new(message, fds);
119 self.send(message_with_fds).await
120 }
121
122 /// Send a JSON-RPC message with optional file descriptors.
123 pub async fn send(&mut self, message_with_fds: MessageWithFds) -> Result<()> {
124 let serialized = if self.pretty {
125 message_with_fds.serialize_pretty()?
126 } else {
127 message_with_fds.serialize()?
128 };
129 let data = serialized.into_bytes();
130
131 trace!(
132 "Sending message: {} with {} FDs",
133 String::from_utf8_lossy(&data).trim(),
134 message_with_fds.file_descriptors.len()
135 );
136
137 let fds = message_with_fds.file_descriptors;
138
139 // Track how many bytes and FDs we've sent so far
140 let mut bytes_sent = 0usize;
141 let mut fds_sent = 0usize;
142
143 // Current max FDs per batch - may be reduced if we hit EINVAL
144 let mut current_max_fds = self.max_fds_per_sendmsg.get();
145
146 // Send data with FDs in batches. Each sendmsg can only handle a limited number of FDs.
147 // We send FDs with the data chunks, and any remaining FDs after all data is sent.
148 while bytes_sent < data.len() || fds_sent < fds.len() {
149 let remaining_data = &data[bytes_sent..];
150 let remaining_fds = &fds[fds_sent..];
151
152 // Determine how many FDs to send in this batch (up to current_max_fds)
153 let fds_batch = remaining_fds
154 .get(..current_max_fds)
155 .unwrap_or(remaining_fds);
156
157 let result = self
158 .stream
159 .async_io(Interest::WRITABLE, || {
160 let sockfd = self.stream.as_fd();
161
162 if !fds_batch.is_empty() {
163 // Send with FDs using sendmsg with ancillary data
164 let borrowed_fds: Vec<_> = fds_batch.iter().map(|fd| fd.as_fd()).collect();
165
166 let mut buffer: [MaybeUninit<u8>;
167 rustix::cmsg_space!(ScmRights(MAX_FDS_PER_RECVMSG))] =
168 [MaybeUninit::uninit();
169 rustix::cmsg_space!(ScmRights(MAX_FDS_PER_RECVMSG))];
170 let mut control = SendAncillaryBuffer::new(&mut buffer);
171
172 if !control.push(SendAncillaryMessage::ScmRights(&borrowed_fds)) {
173 return Err(io::Error::other(
174 "Failed to add file descriptors to control message",
175 ));
176 }
177
178 // If we have data to send, include it; otherwise send a minimal byte
179 // (some systems require non-empty iov for ancillary data)
180 let iov = if !remaining_data.is_empty() {
181 [IoSlice::new(remaining_data)]
182 } else {
183 // Send a space byte that will be ignored by the receiver's JSON parser.
184 // RFC 8259 defines space (0x20) as insignificant whitespace, and
185 // serde_json's StreamDeserializer skips whitespace between values.
186 [IoSlice::new(b" ")]
187 };
188
189 rustix::net::sendmsg(sockfd, &iov, &mut control, SendFlags::empty())
190 .map_err(|e| to_io_error(e, "sendmsg"))
191 } else if !remaining_data.is_empty() {
192 // No FDs left, just send remaining data
193 rustix::net::send(sockfd, remaining_data, SendFlags::empty())
194 .map_err(|e| to_io_error(e, "send"))
195 } else {
196 // Nothing left to send
197 Ok(0)
198 }
199 })
200 .await;
201
202 match result {
203 Ok(sent) => {
204 // Update bytes sent (but only count actual data bytes, not padding)
205 if !remaining_data.is_empty() {
206 bytes_sent += sent;
207 }
208
209 // Update FDs sent
210 if !fds_batch.is_empty() {
211 fds_sent += fds_batch.len();
212 trace!(
213 "Sent {} FDs (total: {}/{}) with {} bytes",
214 fds_batch.len(),
215 fds_sent,
216 fds.len(),
217 sent
218 );
219 }
220
221 trace!(
222 "Progress: {}/{} bytes, {}/{} FDs",
223 bytes_sent,
224 data.len(),
225 fds_sent,
226 fds.len()
227 );
228 }
229 Err(e) if e.kind() == io::ErrorKind::InvalidInput && fds_batch.len() > 1 => {
230 // EINVAL with multiple FDs likely means we exceeded the kernel's
231 // SCM_MAX_FD limit. Reduce batch size and retry.
232 let new_max = fds_batch.len() / 2;
233 debug!(
234 "sendmsg returned EINVAL with {} FDs, reducing batch size to {}",
235 fds_batch.len(),
236 new_max
237 );
238 current_max_fds = new_max;
239 // Don't update bytes_sent or fds_sent - we'll retry this batch
240 continue;
241 }
242 Err(e) => return Err(Error::Io(e)),
243 }
244 }
245
246 // If we discovered a lower limit, remember it for future sends
247 if current_max_fds < self.max_fds_per_sendmsg.get() {
248 debug!(
249 "Learned kernel FD limit: reducing max_fds_per_sendmsg from {} to {}",
250 self.max_fds_per_sendmsg, current_max_fds
251 );
252 // current_max_fds is at least 1 (we only reduce when fds_this_batch > 1)
253 self.max_fds_per_sendmsg =
254 NonZeroUsize::new(current_max_fds).expect("current_max_fds should be >= 1");
255 }
256
257 Ok(())
258 }
259}
260
261/// Receiver half of a Unix socket transport for receiving JSON-RPC messages.
262pub struct Receiver {
263 stream: Arc<TokioUnixStream>,
264 buffer: Vec<u8>,
265 fd_queue: VecDeque<OwnedFd>,
266 /// A fully parsed JSON message waiting for its FDs to arrive.
267 pending_message: Option<(serde_json::Value, usize)>,
268}
269
270impl Receiver {
271 /// Receive a message, returning an error on connection close.
272 ///
273 /// See also [`receive_opt`](Self::receive_opt) which returns `Ok(None)`
274 /// on connection close instead of an error.
275 pub async fn receive(&mut self) -> Result<MessageWithFds> {
276 loop {
277 if let Some(message) = self.try_parse_message()? {
278 return Ok(message);
279 }
280
281 if let Err(e) = self.read_more_data().await {
282 if matches!(e, Error::ConnectionClosed)
283 && let Some((_, fd_count)) = self.pending_message.take()
284 {
285 // Connection closed while waiting for FDs — per spec
286 // Section 5, Step 4 this is a Mismatched Count error.
287 return Err(Error::MismatchedCount {
288 expected: fd_count,
289 found: self.fd_queue.len(),
290 });
291 }
292 return Err(e);
293 }
294 }
295 }
296
297 /// Receive a message, returning `Ok(None)` on connection close.
298 ///
299 /// This is a convenience method that converts `Error::ConnectionClosed`
300 /// to `Ok(None)`, which is useful for receiver loops:
301 ///
302 /// ```ignore
303 /// while let Some(msg) = receiver.receive_opt().await? {
304 /// // handle message
305 /// }
306 /// ```
307 ///
308 /// See also [`receive`](Self::receive) which returns an error on
309 /// connection close.
310 pub async fn receive_opt(&mut self) -> Result<Option<MessageWithFds>> {
311 match self.receive().await {
312 Ok(msg) => Ok(Some(msg)),
313 Err(Error::ConnectionClosed) => Ok(None),
314 Err(e) => Err(e),
315 }
316 }
317
318 /// Build a `MessageWithFds` by draining `fd_count` FDs from the queue.
319 fn build_message(
320 fd_queue: &mut VecDeque<OwnedFd>,
321 value: serde_json::Value,
322 fd_count: usize,
323 ) -> Result<MessageWithFds> {
324 let fds: Vec<OwnedFd> = fd_queue.drain(..fd_count).collect();
325 let message = JsonRpcMessage::from_json_value(value)?;
326 Ok(MessageWithFds::new(message, fds))
327 }
328
329 fn try_parse_message(&mut self) -> Result<Option<MessageWithFds>> {
330 // Check if we have a pending message waiting for FDs.
331 // While a message is pending, all subsequent message parsing is
332 // blocked — even messages needing 0 FDs. This preserves FIFO
333 // ordering on the Unix socket: FDs queued after the pending
334 // message's FDs belong to later messages and must not be
335 // consumed early.
336 if let Some((value, fd_count)) = self
337 .pending_message
338 .take_if(|(_, c)| self.fd_queue.len() >= *c)
339 {
340 return Ok(Some(Self::build_message(
341 &mut self.fd_queue,
342 value,
343 fd_count,
344 )?));
345 } else if let Some((_, fd_count)) = &self.pending_message {
346 // Not enough FDs yet. Per the spec (Section 5, Step 4),
347 // if the buffer contains any non-whitespace byte the sender
348 // has started the next message before delivering all FDs for
349 // the current one — that is a fatal protocol violation.
350 if self.buffer.iter().any(|&b| !b.is_ascii_whitespace()) {
351 return Err(Error::MismatchedCount {
352 expected: *fd_count,
353 found: self.fd_queue.len(),
354 });
355 }
356 return Ok(None);
357 }
358
359 if self.buffer.is_empty() {
360 return Ok(None);
361 }
362
363 // Use streaming JSON parser to find message boundaries
364 let mut stream =
365 serde_json::Deserializer::from_slice(&self.buffer).into_iter::<serde_json::Value>();
366
367 match stream.next() {
368 Some(Ok(value)) => {
369 // Successfully parsed a complete JSON value
370 let bytes_consumed = stream.byte_offset();
371
372 trace!("Parsed message ({} bytes): {:?}", bytes_consumed, value);
373
374 // Drain the consumed bytes from the buffer
375 self.buffer.drain(..bytes_consumed);
376
377 // Read the fds count from the message and extract FDs
378 let fd_count = get_fd_count(&value);
379
380 if fd_count > self.fd_queue.len() {
381 // FDs may arrive across multiple recvmsg() calls when the
382 // sender batches them. Buffer the parsed message and let
383 // the receive() loop read more data.
384 //
385 // Per the spec (Section 5, Step 4), if the buffer already
386 // contains non-whitespace bytes the sender has started the
387 // next message before delivering all FDs — a fatal error.
388 if self.buffer.iter().any(|&b| !b.is_ascii_whitespace()) {
389 return Err(Error::MismatchedCount {
390 expected: fd_count,
391 found: self.fd_queue.len(),
392 });
393 }
394 trace!(
395 "Message expects {} FDs but only {} available, waiting for more",
396 fd_count,
397 self.fd_queue.len()
398 );
399 self.pending_message = Some((value, fd_count));
400 return Ok(None);
401 }
402
403 Ok(Some(Self::build_message(
404 &mut self.fd_queue,
405 value,
406 fd_count,
407 )?))
408 }
409 Some(Err(e)) if e.is_eof() => {
410 // Incomplete JSON - need more data
411 Ok(None)
412 }
413 Some(Err(e)) => {
414 // Actual parse error
415 Err(Error::Json(e))
416 }
417 None => {
418 // No more values (shouldn't happen with non-empty buffer, but handle it)
419 Ok(None)
420 }
421 }
422 }
423
424 async fn read_more_data(&mut self) -> Result<()> {
425 let mut data_buffer = [0u8; READ_BUFFER_SIZE];
426 let mut received_fds: Vec<OwnedFd> = Vec::new();
427
428 let bytes_read = self
429 .stream
430 .async_io(Interest::READABLE, || {
431 let sockfd = self.stream.as_fd();
432
433 let mut iov = [IoSliceMut::new(&mut data_buffer)];
434 let mut cmsg_space: [MaybeUninit<u8>;
435 rustix::cmsg_space!(ScmRights(MAX_FDS_PER_RECVMSG))] =
436 [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(MAX_FDS_PER_RECVMSG))];
437 let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
438
439 let result = rustix::net::recvmsg(
440 sockfd,
441 &mut iov,
442 &mut cmsg_buffer,
443 RecvFlags::CMSG_CLOEXEC,
444 )
445 .map_err(|e| to_io_error(e, "recvmsg"))?;
446
447 // Extract file descriptors from control messages
448 for msg in cmsg_buffer.drain() {
449 if let RecvAncillaryMessage::ScmRights(fds) = msg {
450 received_fds.extend(fds);
451 }
452 }
453
454 Ok(result.bytes)
455 })
456 .await
457 .map_err(Error::Io)?;
458
459 if bytes_read == 0 {
460 return Err(Error::ConnectionClosed);
461 }
462
463 self.buffer.extend_from_slice(&data_buffer[..bytes_read]);
464 self.fd_queue.extend(received_fds);
465
466 debug!(
467 "Read {} bytes, {} FDs in queue",
468 bytes_read,
469 self.fd_queue.len()
470 );
471 Ok(())
472 }
473}
474
475/// Convert a rustix error to an io::Error, preserving EAGAIN/EWOULDBLOCK for async_io
476fn to_io_error(e: rustix::io::Errno, operation: &str) -> io::Error {
477 // rustix::io::Errno can be converted to io::Error, which preserves the error kind
478 let io_err: io::Error = e.into();
479 if io_err.kind() == io::ErrorKind::WouldBlock {
480 io_err
481 } else {
482 io::Error::new(io_err.kind(), format!("{} failed: {}", operation, io_err))
483 }
484}