Skip to main content

miku_ktls/
stream.rs

1pub mod cork;
2
3use std::{
4    io::{self, IoSliceMut},
5    os::unix::io::AsRawFd,
6    pin::Pin,
7    task,
8};
9
10use nix::{
11    errno::Errno,
12    sys::socket::{recvmsg, ControlMessageOwned, MsgFlags, SockaddrIn, TlsGetRecordType},
13};
14use num_enum::FromPrimitive;
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16
17use crate::AsyncReadReady;
18
// A wrapper around `IO` that sends a TLS `close_notify` when it is shut down
// (see `poll_shutdown`) or when `poll_read` receives a `close_notify` / fatal
// alert from the peer.
//
// NOTE(review): an earlier version of this comment also said "or dropped";
// no drop-time `close_notify` is visible in this part of the file -- confirm.
pin_project_lite::pin_project! {
    pub struct KtlsStream<IO>
    where
        IO: AsRawFd
    {
        // The wrapped I/O object; pinned so the async read/write methods can
        // be forwarded to it through the projection.
        #[pin]
        inner: IO,
        // Set once a `close_notify` has been sent by this side; after that
        // `poll_write` reports `Ok(0)` instead of writing.
        write_closed: bool,
        // Set once a closing alert has been read; after that `poll_read`
        // returns immediately without filling the buffer.
        read_closed: bool,
        // Bytes handed to `new` together with the offset of the first byte
        // not yet returned; `poll_read` serves these before touching `inner`.
        drained: Option<(usize, Vec<u8>)>,
    }
}
32
33impl<IO> KtlsStream<IO>
34where
35    IO: AsRawFd,
36{
37    pub fn new(inner: IO, drained: Option<Vec<u8>>) -> Self {
38        Self {
39            inner,
40            write_closed: false,
41            read_closed: false,
42            drained: drained.map(|drained| (0, drained)),
43        }
44    }
45
46    /// Return the drained data + the original I/O
47    pub fn into_raw(self) -> (Option<Vec<u8>>, IO) {
48        (self.drained.map(|(_, drained)| drained), self.inner)
49    }
50
51    /// Returns a reference to the original I/O
52    pub fn get_ref(&self) -> &IO {
53        &self.inner
54    }
55
56    /// Returns a mut reference to the original I/O
57    pub fn get_mut(&mut self) -> &mut IO {
58        &mut self.inner
59    }
60}
61
62#[derive(Debug, PartialEq, Clone, Copy, num_enum::FromPrimitive)]
63#[repr(u8)]
64enum TlsAlertLevel {
65    Warning = 1,
66    Fatal = 2,
67    #[num_enum(catch_all)]
68    Other(u8),
69}
70
71#[derive(Debug, PartialEq, Clone, Copy, num_enum::FromPrimitive)]
72#[repr(u8)]
73enum TlsAlertDescription {
74    CloseNotify = 0,
75    #[num_enum(catch_all)]
76    Other(u8),
77}
78
79impl<'a, IO> AsyncRead for KtlsStream<IO>
80where
81    IO: AsRawFd + AsyncRead + AsyncReadReady<'a>,
82{
83    fn poll_read(
84        self: Pin<&mut Self>,
85        cx: &mut task::Context<'_>,
86        buf: &mut ReadBuf<'_>,
87    ) -> task::Poll<io::Result<()>> {
88        tracing::trace!(buf.remaining = %buf.remaining(), "KtlsStream::poll_read");
89
90        if self.read_closed {
91            return task::Poll::Ready(Ok(()));
92        }
93
94        if buf.remaining() == 0 {
95            return task::Poll::Ready(Ok(()));
96        }
97
98        let mut this = self.project();
99
100        if let Some((drain_index, drained)) = this.drained.as_mut() {
101            let drained = &drained[*drain_index..];
102            let len = std::cmp::min(buf.remaining(), drained.len());
103
104            tracing::trace!(%len, "KtlsStream::poll_read, can take from drain");
105            buf.put_slice(&drained[..len]);
106
107            *drain_index += len;
108            if *drain_index >= drained.len() {
109                tracing::trace!("KtlsStream::poll_read, done draining");
110                *this.drained = None;
111            }
112            cx.waker().wake_by_ref();
113
114            tracing::trace!("KtlsStream::poll_read, returning after drain");
115            return task::Poll::Ready(Ok(()));
116        }
117
118        let read_res = this.inner.as_mut().poll_read(cx, buf);
119        if let task::Poll::Ready(Err(e)) = &read_res {
120            // 5 is a generic "input/output error", it happens when
121            // using poll_read on a kTLS socket that just received
122            // a control message
123            if let Some(5) = e.raw_os_error() {
124                // could be a control message, let's check
125                let fd = this.inner.as_raw_fd();
126
127                // XXX: recvmsg wants a `&mut Vec<u8>` so it's able to resize it
128                // I guess? Or so there's a clear separation between uninitialized
129                // and initialized? We could probably get read of that heap alloc, idk.
130
131                // let mut cmsgspace =
132                //     [0u8; unsafe { libc::CMSG_SPACE(std::mem::size_of::<u8>() as _) as _ }];
133                let mut cmsgspace = Vec::with_capacity(unsafe {
134                    libc::CMSG_SPACE(std::mem::size_of::<u8>() as _) as _
135                });
136
137                let mut iov = [IoSliceMut::new(buf.initialize_unfilled())];
138                let flags = MsgFlags::empty();
139
140                let r = recvmsg::<SockaddrIn>(fd, &mut iov, Some(&mut cmsgspace), flags);
141                let r = match r {
142                    Ok(r) => r,
143                    Err(Errno::EAGAIN) => {
144                        unreachable!("expected a control message, got EAGAIN")
145                    }
146                    Err(e) => {
147                        // ok I guess it really failed then
148                        tracing::trace!(?e, "recvmsg failed");
149                        return Err(e.into()).into();
150                    }
151                };
152                let cmsg = r
153                    .cmsgs()?
154                    .next()
155                    .expect("we should've received exactly one control message");
156
157                let record_type = match cmsg {
158                    ControlMessageOwned::TlsGetRecordType(t) => t,
159                    _ => panic!("unexpected cmsg type: {cmsg:#?}"),
160                };
161
162                match record_type {
163                    TlsGetRecordType::ChangeCipherSpec => {
164                        panic!("change_cipher_spec isn't supported by the ktls crate")
165                    }
166                    TlsGetRecordType::Alert => {
167                        // the alert level and description are in iovs
168                        let iov = r.iovs().next().expect("expected data in iovs");
169
170                        let (level, description) = match iov {
171                            [] => {
172                                // we have an early return case for that
173                                unreachable!();
174                            }
175                            &[level] => {
176                                // https://github.com/facebookincubator/fizz/blob/fff6d9d49d3c554ab66b58822d1e1fe93e8d80f2/fizz/experimental/ktls/AsyncKTLSSocket.cpp#L144
177                                //
178                                // Since all alerts (even warning-level alerts)
179                                // signal the abort of a TLS session, we do not
180                                // need to worry about additional application
181                                // data.
182                                //
183                                // If we only have half the alert (because the
184                                // user passed a buffer of size 1), just assume
185                                // it's a close_notify
186                                (
187                                    TlsAlertLevel::from_primitive(level),
188                                    TlsAlertDescription::CloseNotify,
189                                )
190                            }
191                            &[level, description] => (
192                                TlsAlertLevel::from_primitive(level),
193                                TlsAlertDescription::from_primitive(description),
194                            ),
195                            _ => {
196                                unreachable!(
197                                    "TLS alerts are exactly 2 bytes, your kTLS is misbehaving"
198                                );
199                            }
200                        };
201
202                        match (level, description) {
203                            // https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
204                            // alerts we should handle are ones with fatal level or a
205                            // close_notify
206                            (_, TlsAlertDescription::CloseNotify) | (TlsAlertLevel::Fatal, _) => {
207                                tracing::trace!(?level, ?description, "got TLS alert");
208                                *this.read_closed = true;
209                                *this.write_closed = true;
210                                if let Err(e) =
211                                    crate::ffi::send_close_notify(this.inner.as_raw_fd())
212                                {
213                                    return Err(e).into();
214                                }
215                                // the file descriptor will be closed when the stream is dropped,
216                                // we already protect against writes-after-close_notify through
217                                // the write_closed flag
218                                return task::Poll::Ready(Ok(()));
219                            }
220                            _ => {
221                                // we got something we probably can't handle
222                            }
223                        }
224                        return task::Poll::Ready(Ok(()));
225                    }
226                    TlsGetRecordType::Handshake => {
227                        // TODO: this is where we receive TLS 1.3 resumption tickets,
228                        // should those be stored anywhere? I'm not even sure what
229                        // format they have at this point
230                        tracing::trace!(
231                            "ignoring handshake message (probably a resumption ticket)"
232                        );
233                    }
234                    TlsGetRecordType::ApplicationData => {
235                        unreachable!(
236                            "received TLS application in recvmsg, this is supposed to happen in \
237                             the poll_read codepath"
238                        )
239                    }
240                    TlsGetRecordType::Unknown(t) => {
241                        // just ignore the record?
242                        tracing::trace!("received record_type {t:#?}");
243                    }
244                    _ => {
245                        tracing::trace!("received unsupported record type");
246                    }
247                };
248
249                // FIXME: this is hacky, but can we do better?
250                // after we handled (..ignored) the control message, we don't
251                // know whether the socket is still ready to be read or not.
252                //
253                // we could try looping (tricky code structure), but we can't,
254                // for example, just call `poll_read`, which might fail not
255                // not with EAGAIN/EWOULDBLOCK, but because _another_ control
256                // message is available.
257                cx.waker().wake_by_ref();
258                return task::Poll::Pending;
259            }
260        }
261
262        read_res
263    }
264}
265
266impl<IO> AsyncWrite for KtlsStream<IO>
267where
268    IO: AsRawFd + AsyncWrite,
269{
270    fn poll_write(
271        self: Pin<&mut Self>,
272        cx: &mut task::Context<'_>,
273        buf: &[u8],
274    ) -> task::Poll<io::Result<usize>> {
275        if self.write_closed {
276            return task::Poll::Ready(Ok(0));
277        }
278
279        self.project().inner.poll_write(cx, buf)
280    }
281
282    fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>> {
283        self.project().inner.poll_flush(cx)
284    }
285
286    fn poll_shutdown(
287        self: Pin<&mut Self>,
288        cx: &mut task::Context<'_>,
289    ) -> task::Poll<io::Result<()>> {
290        let this = self.project();
291
292        if !*this.write_closed {
293            // they didn't hang up on us, we're nicely being asked to shut down,
294            // let's send a close_notify (and not wait for them to send it back)
295            *this.write_closed = true;
296            if let Err(e) = crate::ffi::send_close_notify(this.inner.as_raw_fd()) {
297                return Err(e).into();
298            }
299        }
300
301        // this ends up closing the inner file descriptor no matter what
302        this.inner.poll_shutdown(cx)
303    }
304}
305
306impl<IO> AsRawFd for KtlsStream<IO>
307where
308    IO: AsRawFd,
309{
310    fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
311        self.inner.as_raw_fd()
312    }
313}