ktls_stream/
lib.rs

1#![doc = include_str!("../README.md")]
2
3mod log;
4
5use std::io::{self, Read, Write};
6use std::os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd};
7#[cfg(feature = "async-io-tokio")]
8use std::pin::Pin;
9#[cfg(feature = "async-io-tokio")]
10use std::task;
11
12use ktls_core::utils::Buffer;
13use ktls_core::{Context, TlsSession};
14#[cfg(feature = "async-io-tokio")]
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16
17pin_project_lite::pin_project! {
18    #[derive(Debug)]
19    #[project = StreamProj]
20    /// A thin wrapper around a socket with kernel TLS (kTLS) offload configured.
21    ///
22    /// This implements traits [`Read`](std::io::Read) and
23    /// [`Write`](std::io::Write), [`AsyncRead`](tokio::io::AsyncRead) and
24    /// [`AsyncWrite`](tokio::io::AsyncWrite) (when feature `async-io-tokio` is
25    /// enabled).
26    ///
27    /// ## Behaviours
28    ///
29    /// Once a TLS `close_notify` alert from the peer is received, all subsequent
30    /// read operations will return EOF.
31    ///
32    /// Once the caller explicitly calls `(poll_)shutdown` on the stream, all
33    /// subsequent write operations will return 0 bytes, indicating that the
34    /// stream is closed for writing.
35    ///
36    /// Once the stream is being dropped, a `close_notify` alert would be sent to
37    /// the peer automatically before shutting down the inner socket, according to
38    /// [RFC 8446, section 6.1].
39    ///
40    /// The caller may call `(poll_)shutdown` on the stream to shutdown explicitly
41    /// both sides of the stream. Currently, there's no way provided by this crate
42    /// to shutdown the TLS stream write side only. For TLS 1.2, this is ideal since
43    /// once one party sends a `close_notify` alert, *the other party MUST respond
44    /// with a `close_notify` alert of its own and close down the connection
45    /// immediately*, according to [RFC 5246, section 7.2.1]; for TLS 1.3, *both
46    /// parties need not wait to receive a "`close_notify`" alert before
47    /// closing their read side of the connection*, according to [RFC 8446, section
48    /// 6.1].
49    ///
50    /// [RFC 5246, section 7.2.1]: https://tools.ietf.org/html/rfc5246#section-7.2.1
51    /// [RFC 8446, section 6.1]: https://tools.ietf.org/html/rfc8446#section-6.1
52    pub struct Stream<S: AsFd, C: TlsSession> {
53        #[pin]
54        inner: S,
55
56        // Context of the kTLS connection.
57        context: Context<C>,
58    }
59
60    impl<S: AsFd, C: TlsSession> PinnedDrop for Stream<S, C> {
61        fn drop(this: Pin<&mut Self>) {
62            let this = this.project();
63
64            this.context.shutdown(&*this.inner);
65        }
66    }
67}
68
69impl<S: AsFd, C: TlsSession> Stream<S, C> {
70    /// Creates a new kTLS stream from the given socket, TLS session and an
71    /// optional buffer (may be early data received from peer during
72    /// handshaking).
73    ///
74    /// # Prerequisites
75    ///
76    /// - The socket must have TLS ULP configured with
77    ///   [`setup_ulp`](ktls_core::setup_ulp).
78    /// - The TLS handshake must be completed.
79    pub fn new(socket: S, session: C, buffer: Option<Buffer>) -> Self {
80        Self {
81            inner: socket,
82            context: Context::new(session, buffer),
83        }
84    }
85
86    /// Returns a mutable reference to the inner socket if the TLS connection is
87    /// not closed (unidirectionally or bidirectionally).
88    ///
89    /// This requires a mutable reference to the [`Stream`] to ensure a
90    /// exclusive access to the inner socket.
91    ///
92    /// ## Notes
93    ///
94    /// * All buffered data **MUST** be properly consumed (See
95    ///   [`AccessRawStreamError::HasBufferedData`]).
96    ///
97    ///   The buffered data typically consists of:
98    ///
99    ///   - Early data received during handshake.
100    ///   - Application data received due to improper usage of
101    ///     [`StreamRefMutRaw::handle_io_error`].
102    ///
103    /// * The caller **MAY** handle any [`io::Result`]s returned by I/O
104    ///   operations on the inner socket with
105    ///   [`StreamRefMutRaw::handle_io_error`].
106    ///
107    /// * The caller **MUST NOT** shutdown the inner socket directly, which will
108    ///   lead to undefined behaviours. Instead, the caller **MAY** call
109    ///   `(poll_)shutdown` explictly on the [`Stream`] to gracefully shutdown
110    ///   the TLS stream (with `close_notify` be sent) manually, or just drop
111    ///   the stream to do automatic graceful shutdown.
112    /// 
113    /// # Errors
114    /// 
115    /// See [`AccessRawStreamError`].
116    pub fn as_mut_raw(&mut self) -> Result<StreamRefMutRaw<'_, S, C>, AccessRawStreamError> {
117        if let Some(buffer) = self.context.buffer_mut().drain() {
118            return Err(AccessRawStreamError::HasBufferedData(buffer));
119        }
120
121        let state = self.context.state();
122
123        if state.is_closed() {
124            // Fully closed, just return error.
125            return Err(AccessRawStreamError::Closed);
126        }
127
128        Ok(StreamRefMutRaw { this: self })
129    }
130}
131
#[cfg(feature = "shim-rustls")]
impl<S, Data> Stream<S, rustls::kernel::KernelConnection<Data>>
where
    S: AsFd,
    rustls::kernel::KernelConnection<Data>: TlsSession,
{
    /// Constructs a new [`Stream`] from a socket, TLS secrets, and TLS session
    /// context.
    ///
    /// # Overview
    ///
    /// This installs the extracted TLS secrets
    /// ([`rustls::ExtractedSecrets`]) on the socket for both directions and
    /// then wraps the socket together with the TLS session context
    /// ([`rustls::kernel::KernelConnection`]) into a [`Stream`]. An optional
    /// buffer may be provided for early data received during the handshake.
    ///
    /// The secrets and context must be extracted from a
    /// [`rustls::client::UnbufferedClientConnection`] or a
    /// [`rustls::server::UnbufferedServerConnection`]. See the
    /// [`rustls::kernel`] module documentation for more details.
    ///
    /// ## Prerequisites
    ///
    /// The socket must have TLS ULP configured with
    /// [`setup_ulp`](ktls_core::setup_ulp).
    ///
    /// ## Errors
    ///
    /// Returns an error if the secrets cannot be converted into kernel crypto
    /// parameters, or if configuring the kernel TLS parameters on the socket
    /// fails (for example because the prerequisites are not met).
    pub fn from(
        socket: S,
        secrets: rustls::ExtractedSecrets,
        session: rustls::kernel::KernelConnection<Data>,
        buffer: Option<Buffer>,
    ) -> Result<Self, ktls_core::Error> {
        use ktls_core::{TlsCryptoInfoRx, TlsCryptoInfoTx};

        // Split the secrets into (sequence number, key material) per direction.
        let rustls::ExtractedSecrets {
            tx: (seq_tx, secrets_tx),
            rx: (seq_rx, secrets_rx),
        } = secrets;

        let tls_crypto_info_tx = TlsCryptoInfoTx::new(
            session.protocol_version().into(),
            secrets_tx.try_into()?,
            seq_tx,
        )?;

        let tls_crypto_info_rx = TlsCryptoInfoRx::new(
            session.protocol_version().into(),
            secrets_rx.try_into()?,
            seq_rx,
        )?;

        // Hand both directions' parameters to the kernel before any
        // application I/O goes through the wrapper.
        ktls_core::setup_tls_params(&socket, &tls_crypto_info_tx, &tls_crypto_info_rx)?;

        Ok(Self::new(socket, session, buffer))
    }
}
191
// Runs a fallible I/O expression in a retry loop.
//
// `$this` must expose `context` and `inner` fields. A successful result is
// returned from the enclosing function right away. On error, the error is
// passed to `context.handle_io_error`; if that call succeeds the expression is
// evaluated again, otherwise the error it returns is propagated with `?`.
macro_rules! handle_ret {
    ($this:expr, $($tt:tt)+) => {
        loop {
            match $($tt)+ {
                Ok(value) => return Ok(value),
                Err(err) => $this.context.handle_io_error(&$this.inner, err)?,
            }
        }
    };
}
204
205impl<S, C> Read for Stream<S, C>
206where
207    S: AsFd + Read,
208    C: TlsSession,
209{
210    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
211        handle_ret!(self, {
212            if self.context.state().is_read_closed() {
213                return Ok(0);
214            }
215
216            let read_from_buffer = self.context.buffer_mut().read(|data| {
217                let amt = buf.len().min(data.len());
218                buf[..amt].copy_from_slice(&data[..amt]);
219                amt
220            });
221
222            if let Some(read_from_buffer) = read_from_buffer {
223                return Ok(read_from_buffer.get());
224            }
225
226            // Retry is OK, the implementation of `Read` requires no data will be
227            // read into the buffer when error occurs.
228            self.inner.read(buf)
229        })
230    }
231}
232
233macro_rules! impl_shutdown {
234    ($ty:ty) => {
235        impl<C> Stream<$ty, C>
236        where
237            C: TlsSession,
238        {
239            /// Shuts down both read and write sides of the TLS stream.
240            pub fn shutdown(&mut self) {
241                let is_write_closed = self.context.state().is_write_closed();
242
243                self.context.shutdown(&self.inner);
244
245                if !is_write_closed {
246                    let _ = self
247                        .inner
248                        .shutdown(std::net::Shutdown::Write);
249                }
250            }
251        }
252    };
253}
254
255impl_shutdown!(std::net::TcpStream);
256impl_shutdown!(std::os::unix::net::UnixStream);
257
258impl<S, C> Write for Stream<S, C>
259where
260    S: AsFd + Write,
261    C: TlsSession,
262{
263    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
264        handle_ret!(self, {
265            if self.context.state().is_write_closed() {
266                // Write side is closed, return EOF.
267                return Ok(0);
268            }
269
270            // Retry is OK, the implementation of `Write` requires no data will
271            // be written when error occurs.
272            self.inner.write(buf)
273        })
274    }
275
276    fn flush(&mut self) -> io::Result<()> {
277        handle_ret!(self, {
278            if self.context.state().is_write_closed() {
279                // Write side is closed, return directly.
280                return Ok(());
281            }
282
283            self.inner.flush()
284        })
285    }
286}
287
// Poll-based counterpart of `handle_ret!`.
//
// `$this` must expose `context` and a pinned `inner`. `Pending` and
// `Ready(Ok(_))` are returned from the enclosing function unchanged. On
// `Ready(Err(_))` the error is passed to `context.handle_io_error`; if that
// call succeeds the expression is polled again, otherwise the error it
// returns is propagated with `?`.
#[cfg(feature = "async-io-tokio")]
macro_rules! handle_ret_async {
    ($this:expr, $($tt:tt)+) => {
        loop {
            match $($tt)+ {
                std::task::Poll::Ready(Err(err)) => {
                    $this.context.handle_io_error(&*$this.inner, err)?;
                }
                settled => return settled,
            }
        }
    };
}
302
#[cfg(feature = "async-io-tokio")]
impl<S, C> AsyncRead for Stream<S, C>
where
    S: AsFd + AsyncRead,
    C: TlsSession,
{
    /// Polls for decrypted data and appends it to `buf`.
    ///
    /// Completes immediately without adding data (EOF) once the read side is
    /// reported closed. Data held in the context buffer is served before the
    /// socket is polled, and the call completes as soon as buffered data has
    /// been copied, mirroring the blocking [`Read`] implementation. Socket
    /// errors are handed to the context and the poll is retried when the
    /// context reports the error as handled.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> task::Poll<io::Result<()>> {
        let mut this = self.project();

        handle_ret_async!(this, {
            if this.context.state().is_read_closed() {
                return task::Poll::Ready(Ok(()));
            }

            let read_from_buffer = this.context.buffer_mut().read(|data| {
                let amt = buf.remaining().min(data.len());
                buf.put_slice(&data[..amt]);
                amt
            });

            // Buffered bytes were delivered: report them now. Polling the
            // socket afterwards could yield `Pending`, and the caller would
            // then never learn about the bytes already placed into `buf`.
            if read_from_buffer.is_some() {
                return task::Poll::Ready(Ok(()));
            }

            this.inner.as_mut().poll_read(cx, buf)
        })
    }
}
331
#[cfg(feature = "async-io-tokio")]
impl<S, C> AsyncWrite for Stream<S, C>
where
    S: AsFd + AsyncWrite,
    C: TlsSession,
{
    /// Polls a write of `buf` to the socket.
    ///
    /// Completes with `Ok(0)` when the write side is reported closed. Socket
    /// errors are handed to the context and the poll is retried when the
    /// context reports the error as handled.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &[u8],
    ) -> task::Poll<io::Result<usize>> {
        let mut this = self.project();

        handle_ret_async!(this, {
            let write_closed = this.context.state().is_write_closed();
            if write_closed {
                crate::trace!("Write closed, returning EOF");

                return task::Poll::Ready(Ok(0));
            }

            this.inner.as_mut().poll_write(cx, buf)
        })
    }

    /// Polls a flush of the inner socket, or completes immediately when the
    /// write side is reported closed.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>> {
        let mut this = self.project();

        handle_ret_async!(this, {
            let write_closed = this.context.state().is_write_closed();
            if write_closed {
                crate::trace!("Write closed, skipping flush");

                return task::Poll::Ready(Ok(()));
            }

            this.inner.as_mut().poll_flush(cx)
        })
    }

    /// Runs the context shutdown routine and then, unless the write side was
    /// already closed beforehand, polls the shutdown of the inner socket.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> task::Poll<io::Result<()>> {
        let this = self.project();

        // Sample the state before the context shutdown runs.
        let was_write_closed = this.context.state().is_write_closed();

        // Tell the peer that the write side is going away.
        this.context.shutdown(&*this.inner);

        // NOTE(review): presumably `context.shutdown` marks the write side as
        // closed. If so, a `Pending` from the inner `poll_shutdown` is not
        // resumed on the next poll, which would return `Ready` through the
        // early exit below. Confirm against `Context::shutdown`.
        if !was_write_closed {
            return this.inner.poll_shutdown(cx);
        }

        task::Poll::Ready(Ok(()))
    }
}
388
389/// See [`Stream::as_mut_raw`].
390pub struct StreamRefMutRaw<'a, S: AsFd, C: TlsSession> {
391    this: &'a mut Stream<S, C>,
392}
393
394impl<S: AsFd, C: TlsSession> StreamRefMutRaw<'_, S, C> {
395    /// Performs an I/O operation on the inner socket, handling possible errors
396    /// with [`Context::handle_io_error`].
397    ///
398    /// # Errors
399    ///
400    /// Returns the original I/O error that is unrecoverable.
401    pub fn try_io<F, R>(&mut self, mut f: F) -> io::Result<R>
402    where
403        F: FnMut(&mut S) -> io::Result<R>,
404    {
405        handle_ret!(self.this, f(&mut self.this.inner));
406    }
407
408    /// See [`Context::handle_io_error`].
409    ///
410    /// # Errors
411    ///
412    /// Returns the original I/O error that is unrecoverable.
413    pub fn handle_io_error(&mut self, err: io::Error) -> io::Result<()> {
414        self.this
415            .context
416            .handle_io_error(&self.this.inner, err)
417    }
418}
419
420impl<S: AsFd, C: TlsSession> AsFd for StreamRefMutRaw<'_, S, C> {
421    fn as_fd(&self) -> BorrowedFd<'_> {
422        self.this.inner.as_fd()
423    }
424}
425
426impl<S: AsFd, C: TlsSession> AsRawFd for StreamRefMutRaw<'_, S, C> {
427    fn as_raw_fd(&self) -> RawFd {
428        self.this.inner.as_fd().as_raw_fd()
429    }
430}
431
#[non_exhaustive]
#[derive(Debug)]
/// An error indicating that the inner socket cannot be accessed directly.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// propagated with `?` into boxed or wrapped error types.
pub enum AccessRawStreamError {
    /// The TLS connection is fully closed (both read and write sides).
    Closed,

    /// There's still buffered data that has not been retrieved yet.
    ///
    /// The payload is the buffered data, handed back to the caller.
    HasBufferedData(Vec<u8>),
}

impl std::fmt::Display for AccessRawStreamError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Closed => f.write_str("the TLS connection is closed"),
            Self::HasBufferedData(data) => write!(
                f,
                "{} byte(s) of buffered data must be consumed before accessing the raw stream",
                data.len()
            ),
        }
    }
}

impl std::error::Error for AccessRawStreamError {}