1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
use nix::{
errno::Errno,
sys::socket::{recvmsg, ControlMessageOwned, MsgFlags, SockaddrIn, TlsGetRecordType},
};
use num_enum::FromPrimitive;
use std::{
io::{self, IoSliceMut},
os::unix::prelude::AsRawFd,
pin::Pin,
task,
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
// A wrapper around `IO` that sends a `close_notify` when shut down or dropped.
pin_project_lite::pin_project! {
pub struct KtlsStream<IO>
where
IO: AsRawFd
{
#[pin]
inner: IO,
write_closed: bool,
read_closed: bool,
drained: Option<(usize, Vec<u8>)>,
}
}
impl<IO> KtlsStream<IO>
where
IO: AsRawFd,
{
pub fn new(inner: IO, drained: Option<Vec<u8>>) -> Self {
Self {
inner,
write_closed: false,
read_closed: false,
drained: drained.map(|drained| (0, drained)),
}
}
/// Return the drained data + the original I/O
pub fn into_raw(self) -> (Option<Vec<u8>>, IO) {
(self.drained.map(|(_, drained)| drained), self.inner)
}
/// Returns a reference to the original I/O
pub fn get_ref(&self) -> &IO {
&self.inner
}
/// Returns a mut reference to the original I/O
pub fn get_mut(&mut self) -> &mut IO {
&mut self.inner
}
}
#[derive(Debug, PartialEq, Clone, Copy, num_enum::FromPrimitive)]
#[repr(u8)]
enum TlsAlertLevel {
Warning = 1,
Fatal = 2,
#[num_enum(catch_all)]
Other(u8),
}
#[derive(Debug, PartialEq, Clone, Copy, num_enum::FromPrimitive)]
#[repr(u8)]
enum TlsAlertDescription {
CloseNotify = 0,
#[num_enum(catch_all)]
Other(u8),
}
impl<IO> AsyncRead for KtlsStream<IO>
where
IO: AsRawFd + AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
tracing::trace!(buf.remaining = %buf.remaining(), "KtlsStream::poll_read");
if self.read_closed {
return task::Poll::Ready(Ok(()));
}
if buf.remaining() == 0 {
return task::Poll::Ready(Ok(()));
}
let mut this = self.project();
if let Some((drain_index, drained)) = this.drained.as_mut() {
let drained = &drained[*drain_index..];
let len = std::cmp::min(buf.remaining(), drained.len());
tracing::trace!(%len, "KtlsStream::poll_read, can take from drain");
buf.put_slice(&drained[..len]);
*drain_index += len;
if *drain_index >= drained.len() {
tracing::trace!("KtlsStream::poll_read, done draining");
*this.drained = None;
}
cx.waker().wake_by_ref();
tracing::trace!("KtlsStream::poll_read, returning after drain");
return task::Poll::Ready(Ok(()));
}
let read_res = this.inner.as_mut().poll_read(cx, buf);
if let task::Poll::Ready(Err(e)) = &read_res {
// 5 is a generic "input/output error", it happens when
// using poll_read on a kTLS socket that just received
// a control message
if let Some(5) = e.raw_os_error() {
// could be a control message, let's check
let fd = this.inner.as_raw_fd();
// XXX: recvmsg wants a `&mut Vec<u8>` so it's able to resize it
// I guess? Or so there's a clear separation between uninitialized
// and initialized? We could probably get read of that heap alloc, idk.
// let mut cmsgspace =
// [0u8; unsafe { libc::CMSG_SPACE(std::mem::size_of::<u8>() as _) as _ }];
let mut cmsgspace =
Vec::with_capacity(unsafe { libc::CMSG_SPACE(std::mem::size_of::<u8>() as _) as _ });
let mut iov = [IoSliceMut::new(buf.initialize_unfilled())];
let flags = MsgFlags::empty();
let r = recvmsg::<SockaddrIn>(fd, &mut iov, Some(&mut cmsgspace), flags);
let r = match r {
Ok(r) => r,
Err(Errno::EAGAIN) => {
unreachable!("expected a control message, got EAGAIN")
}
Err(e) => {
// ok I guess it really failed then
tracing::trace!(?e, "recvmsg failed");
return Err(e.into()).into();
}
};
let cmsg = r
.cmsgs()?
.next()
.expect("we should've received exactly one control message");
let record_type = match cmsg {
ControlMessageOwned::TlsGetRecordType(t) => t,
_ => panic!("unexpected cmsg type: {cmsg:#?}"),
};
match record_type {
TlsGetRecordType::ChangeCipherSpec => {
panic!("change_cipher_spec isn't supported by the ktls crate")
}
TlsGetRecordType::Alert => {
// the alert level and description are in iovs
let iov = r.iovs().next().expect("expected data in iovs");
let (level, description) = match iov {
[] => {
// we have an early return case for that
unreachable!();
}
&[level] => {
// https://github.com/facebookincubator/fizz/blob/fff6d9d49d3c554ab66b58822d1e1fe93e8d80f2/fizz/experimental/ktls/AsyncKTLSSocket.cpp#L144
//
// Since all alerts (even warning-level alerts)
// signal the abort of a TLS session, we do not
// need to worry about additional application
// data.
//
// If we only have half the alert (because the
// user passed a buffer of size 1), just assume
// it's a close_notify
(
TlsAlertLevel::from_primitive(level),
TlsAlertDescription::CloseNotify,
)
}
&[level, description] => (
TlsAlertLevel::from_primitive(level),
TlsAlertDescription::from_primitive(description),
),
_ => {
unreachable!("TLS alerts are exactly 2 bytes, your kTLS is misbehaving");
}
};
match (level, description) {
// https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
// alerts we should handle are ones with fatal level or a
// close_notify
(_, TlsAlertDescription::CloseNotify) | (TlsAlertLevel::Fatal, _) => {
tracing::trace!(?level, ?description, "got TLS alert");
*this.read_closed = true;
*this.write_closed = true;
if let Err(e) = crate::ffi::send_close_notify(this.inner.as_raw_fd()) {
return Err(e).into();
}
// the file descriptor will be closed when the stream is dropped,
// we already protect against writes-after-close_notify through
// the write_closed flag
return task::Poll::Ready(Ok(()));
}
_ => {
// we got something we probably can't handle
}
}
return task::Poll::Ready(Ok(()));
}
TlsGetRecordType::Handshake => {
// TODO: this is where we receive TLS 1.3 resumption tickets,
// should those be stored anywhere? I'm not even sure what
// format they have at this point
tracing::trace!("ignoring handshake message (probably a resumption ticket)");
}
TlsGetRecordType::ApplicationData => {
unreachable!("received TLS application in recvmsg, this is supposed to happen in the poll_read codepath")
}
TlsGetRecordType::Unknown(t) => {
// just ignore the record?
tracing::trace!("received record_type {t:#?}");
}
_ => {
tracing::trace!("received unsupported record type");
}
};
// FIXME: this is hacky, but can we do better?
// after we handled (..ignored) the control message, we don't
// know whether the socket is still ready to be read or not.
//
// we could try looping (tricky code structure), but we can't,
// for example, just call `poll_read`, which might fail not
// with EAGAIN/EWOULDBLOCK, but because _another_ control
// message is available.
cx.waker().wake_by_ref();
return task::Poll::Pending;
}
}
read_res
}
}
impl<IO> AsyncWrite for KtlsStream<IO>
where
IO: AsRawFd + AsyncWrite,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
if self.write_closed {
return task::Poll::Ready(Ok(0));
}
self.project().inner.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>> {
self.project().inner.poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>> {
let this = self.project();
if !*this.write_closed {
// they didn't hang up on us, we're nicely being asked to shut down,
// let's send a close_notify (and not wait for them to send it back)
*this.write_closed = true;
if let Err(e) = crate::ffi::send_close_notify(this.inner.as_raw_fd()) {
return Err(e).into();
}
}
// this ends up closing the inner file descriptor no matter what
this.inner.poll_shutdown(cx)
}
}
impl<IO> AsRawFd for KtlsStream<IO>
where
IO: AsRawFd,
{
fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
self.inner.as_raw_fd()
}
}