1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
//! Contains connection related API.
use core::fmt::Debug;
#[cfg(feature = "std")]
use alloc::collections::VecDeque;
use alloc::vec::Vec;
use serde::Serialize;
use super::{BUFFER_SIZE, Call, Reply, socket::WriteHalf};
#[cfg(feature = "std")]
use std::os::fd::OwnedFd;
/// A connection.
///
/// The low-level API to send messages.
///
/// # Cancel safety
///
/// All async methods of this type are cancel safe unless explicitly stated otherwise in its
/// documentation.
#[derive(Debug)]
pub struct WriteConnection<Write: WriteHalf> {
pub(super) socket: Write,
pub(super) buffer: Vec<u8>,
pub(super) pos: usize,
id: usize,
#[cfg(feature = "std")]
pending_fds: VecDeque<MessageFds>,
// On macOS, SCM_RIGHTS does not properly hold a reference to the underlying file description
// when the sender closes the FD in the same process before the receiver calls `recvmsg`. The
// receiver ends up with a stale FD. We work around this by keeping sent FDs alive until the
// read half confirms it has called `recvmsg` for them via `drain_held_fds`.
#[cfg(all(feature = "std", target_os = "macos"))]
pub(super) held_fds: VecDeque<Vec<OwnedFd>>,
}
impl<Write: WriteHalf> WriteConnection<Write> {
/// Create a new connection.
pub(super) fn new(socket: Write, id: usize) -> Self {
Self {
socket,
id,
buffer: alloc::vec![0; BUFFER_SIZE],
pos: 0,
#[cfg(feature = "std")]
pending_fds: VecDeque::new(),
#[cfg(all(feature = "std", target_os = "macos"))]
held_fds: VecDeque::new(),
}
}
/// The unique identifier of the connection.
#[inline]
pub fn id(&self) -> usize {
self.id
}
/// Sends a method call.
///
/// The generic `Method` is the type of the method name and its input parameters. This should be
/// a type that can serialize itself to a complete method call message, i-e an object containing
/// `method` and `parameter` fields. This can be easily achieved using the `serde::Serialize`
/// derive:
///
/// ```rust
/// use serde::{Serialize, Deserialize};
/// use serde_prefix_all::prefix_all;
///
/// #[prefix_all("org.example.ftl.")]
/// #[derive(Debug, Serialize, Deserialize)]
/// #[serde(tag = "method", content = "parameters")]
/// enum MyMethods<'m> {
/// // The name needs to be the fully-qualified name of the error.
/// Alpha { param1: u32, param2: &'m str},
/// Bravo,
/// Charlie { param1: &'m str },
/// }
/// ```
///
/// The `fds` parameter contains file descriptors to send along with the call.
pub async fn send_call<Method>(
&mut self,
call: &Call<Method>,
#[cfg(feature = "std")] fds: Vec<OwnedFd>,
) -> crate::Result<()>
where
Method: Serialize + Debug,
{
trace!("connection {}: sending call: {:?}", self.id, call);
#[cfg(feature = "std")]
{
self.write(call, fds).await
}
#[cfg(not(feature = "std"))]
{
self.write(call).await
}
}
/// Send a reply over the socket.
///
/// The generic parameter `Params` is the type of the successful reply. This should be a type
/// that can serialize itself as the `parameters` field of the reply.
///
/// The `fds` parameter contains file descriptors to send along with the reply.
pub async fn send_reply<Params>(
&mut self,
reply: &Reply<Params>,
#[cfg(feature = "std")] fds: Vec<OwnedFd>,
) -> crate::Result<()>
where
Params: Serialize + Debug,
{
trace!("connection {}: sending reply: {:?}", self.id, reply);
#[cfg(feature = "std")]
{
self.write(reply, fds).await
}
#[cfg(not(feature = "std"))]
{
self.write(reply).await
}
}
/// Send an error reply over the socket.
///
/// The generic parameter `ReplyError` is the type of the error reply. This should be a type
/// that can serialize itself to the whole reply object, containing `error` and `parameter`
/// fields. This can be easily achieved using the `serde::Serialize` derive (See the code
/// snippet in [`super::ReadConnection::receive_reply`] documentation for an example).
///
/// The `fds` parameter contains file descriptors to send along with the error.
pub async fn send_error<ReplyError>(
&mut self,
error: &ReplyError,
#[cfg(feature = "std")] fds: Vec<OwnedFd>,
) -> crate::Result<()>
where
ReplyError: Serialize + Debug,
{
trace!("connection {}: sending error: {:?}", self.id, error);
#[cfg(feature = "std")]
{
self.write(error, fds).await
}
#[cfg(not(feature = "std"))]
{
self.write(error).await
}
}
/// Enqueue a call to be sent over the socket.
///
/// Similar to [`WriteConnection::send_call`], except that the call is not sent immediately but
/// enqueued for later sending. This is useful when you want to send multiple calls in a
/// batch.
///
/// The `fds` parameter contains file descriptors to send along with the call.
pub fn enqueue_call<Method>(
&mut self,
call: &Call<Method>,
#[cfg(feature = "std")] fds: Vec<OwnedFd>,
) -> crate::Result<()>
where
Method: Serialize + Debug,
{
trace!("connection {}: enqueuing call: {:?}", self.id, call);
#[cfg(feature = "std")]
{
self.enqueue(call, fds)
}
#[cfg(not(feature = "std"))]
{
self.enqueue(call)
}
}
/// Send out the enqueued calls.
pub async fn flush(&mut self) -> crate::Result<()> {
if self.pos == 0 {
return Ok(());
}
#[allow(unused_mut)]
let mut sent_pos = 0;
#[cfg(feature = "std")]
{
// While there are FDs, send one message at a time.
while !self.pending_fds.is_empty() {
// Get the first FD entry.
let pending = self.pending_fds.front().unwrap();
let fd_offset = pending.offset;
let msg_len = pending.len;
// If there are bytes before the FD message, send them first without FDs.
if sent_pos < fd_offset {
trace!(
"connection {}: flushing {} bytes before FD message",
self.id,
fd_offset - sent_pos
);
self.socket
.write(&self.buffer[sent_pos..fd_offset], &[] as &[OwnedFd])
.await?;
}
// Send this message with its FDs.
let msg_end = fd_offset + msg_len;
let MessageFds {
fds,
offset: _,
len: _,
} = self.pending_fds.pop_front().unwrap();
trace!(
"connection {}: flushing {} bytes with {} FDs",
self.id,
msg_len,
fds.len()
);
self.socket
.write(&self.buffer[fd_offset..msg_end], &fds)
.await?;
sent_pos = msg_end;
// On macOS, keep sent FDs alive until the read half confirms receipt via
// `recvmsg`. See the comment on the `held_fds` field for details.
#[cfg(target_os = "macos")]
self.held_fds.push_back(fds);
}
}
// No more FDs, send all remaining bytes at once.
if sent_pos < self.pos {
trace!(
"connection {}: flushing {} bytes",
self.id,
self.pos - sent_pos
);
#[cfg(feature = "std")]
{
self.socket
.write(&self.buffer[sent_pos..self.pos], &[] as &[OwnedFd])
.await?;
}
#[cfg(not(feature = "std"))]
{
self.socket.write(&self.buffer[sent_pos..self.pos]).await?;
}
}
self.pos = 0;
Ok(())
}
/// The underlying write half of the socket.
pub fn write_half(&self) -> &Write {
&self.socket
}
pub(super) async fn write<T>(
&mut self,
value: &T,
#[cfg(feature = "std")] fds: Vec<OwnedFd>,
) -> crate::Result<()>
where
T: Serialize + ?Sized + Debug,
{
#[cfg(feature = "std")]
{
self.enqueue(value, fds)?;
}
#[cfg(not(feature = "std"))]
{
self.enqueue(value)?;
}
self.flush().await
}
pub(super) fn enqueue<T>(
&mut self,
value: &T,
#[cfg(feature = "std")] fds: Vec<OwnedFd>,
) -> crate::Result<()>
where
T: Serialize + ?Sized + Debug,
{
#[cfg(feature = "std")]
let start_pos = self.pos;
let len = loop {
match crate::json_ser::to_slice(value, &mut self.buffer[self.pos..]) {
Ok(len) => break len,
Err(crate::json_ser::Error::BufferTooSmall) => {
// Buffer too small, grow it and retry
self.grow_buffer()?;
}
Err(crate::json_ser::Error::KeyMustBeAString) => {
// Actual serialization error
// Convert to serde_json::Error for public API
return Err(crate::Error::Json(serde::ser::Error::custom(
"key must be a string",
)));
}
}
};
// Add null terminator after this message.
if self.pos + len == self.buffer.len() {
self.grow_buffer()?;
}
self.buffer[self.pos + len] = b'\0';
self.pos += len + 1;
// Store FDs with message offset and length.
#[cfg(feature = "std")]
if !fds.is_empty() {
self.pending_fds.push_back(MessageFds {
offset: start_pos,
len: len + 1, // Include null terminator.
fds,
});
}
Ok(())
}
fn grow_buffer(&mut self) -> crate::Result<()> {
if self.buffer.len() >= super::MAX_BUFFER_SIZE {
return Err(crate::Error::BufferOverflow);
}
self.buffer.extend_from_slice(&[0; BUFFER_SIZE]);
Ok(())
}
}
/// Information about file descriptors pending to be sent with a message.
#[cfg(feature = "std")]
#[derive(Debug)]
struct MessageFds {
/// File descriptors to send with this message.
fds: Vec<OwnedFd>,
/// Buffer offset where the message starts.
offset: usize,
/// Length of the message including the null terminator.
len: usize,
}