ktls_stream/
lib.rs

1#![doc = include_str!("../README.md")]
2
3mod log;
4pub mod prelude {
5    //! A "prelude".
6    //!
7    //! This prelude is similar to the standard library's prelude in that you'll
8    //! almost always want to import its entire contents, but unlike the
9    //! standard library's prelude you'll have to do so manually:
10    //!
11    //! ```
12    //! # #[allow(unused_imports)]
13    //! use ktls_stream::prelude::*;
14    //! ```
15    //!
16    //! The prelude may grow over time as additional items see ubiquitous use.
17    //!
18    //! Generally, you don't need to add `ktls-core` as a dependency in your
19    //! `Cargo.toml` unless you are implementing custom TLS session types, etc.
20
21    pub use ktls_core::setup_ulp;
22    #[cfg(feature = "probe-ktls-compatibility")]
23    pub use ktls_core::{Compatibilities, Compatibility};
24
25    pub use crate::Stream;
26}
27
28use std::io::{self, Read, Write};
29use std::os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd};
30
31use ktls_core::{
32    setup_tls_params, setup_ulp, Buffer, Context, DummyTlsSession, ExtractedSecrets,
33    TlsCryptoInfoRx, TlsCryptoInfoTx, TlsSession,
34};
35
pin_project_lite::pin_project! {
    #[derive(Debug)]
    #[project = StreamProj]
    /// A thin wrapper around a kTLS-offloaded socket.
    ///
    /// This implements [`Read`](std::io::Read) and [`Write`](std::io::Write),
    /// as well as [`AsyncRead`](tokio::io::AsyncRead) and
    /// [`AsyncWrite`](tokio::io::AsyncWrite) when the `async-io-tokio` feature
    /// is enabled.
    ///
    /// # Behaviours
    ///
    /// Once a `close_notify` alert has been received from the peer, all
    /// subsequent read operations return EOF (unless the inner buffer still
    /// contains unread data). Once the caller explicitly calls `shutdown` /
    /// `poll_shutdown` on the stream, a `close_notify` alert is sent to the
    /// peer and all subsequent write operations return 0 bytes, indicating
    /// that the stream is closed for writing. When the [`Stream`] is dropped,
    /// it also performs a graceful shutdown automatically.
    ///
    /// For TLS 1.2, once one party sends a `close_notify` alert, *the other party
    /// MUST respond with a `close_notify` alert of its own and close down the
    /// connection immediately*, according to [RFC 5246, section 7.2.1]; for TLS
    /// 1.3, *both parties need not wait to receive a "`close_notify`" alert before
    /// closing their read side of the connection*, according to [RFC 8446, section
    /// 6.1].
    ///
    /// [RFC 5246, section 7.2.1]: https://tools.ietf.org/html/rfc5246#section-7.2.1
    /// [RFC 8446, section 6.1]: https://tools.ietf.org/html/rfc8446#section-6.1
    pub struct Stream<S: AsFd, C: TlsSession> {
        // The underlying socket; `Stream::new` installs the kTLS parameters on
        // it. Pinned so the async I/O impls can poll it in place.
        #[pin]
        inner: S,

        // Context of the kTLS connection (session state and buffered data).
        context: Context<C>,
    }

    impl<S: AsFd, C: TlsSession> PinnedDrop for Stream<S, C> {
        // Runs the kTLS context's shutdown against the socket when the stream
        // is dropped, which is what makes dropping equivalent to a graceful
        // shutdown.
        fn drop(this: Pin<&mut Self>) {
            let this = this.project();

            this.context.shutdown(&*this.inner);
        }
    }
}
81
82impl<S: AsFd, C: TlsSession> Stream<S, C> {
83    /// Constructs a new [`Stream`] from the provided `socket`, extracted TLS
84    /// `secrets` and TLS `session` context. An optional `buffer` may be
85    /// provided for early data received during handshake.
86    ///
87    /// ## Prerequisites
88    ///
89    /// The socket must have TLS ULP configured with [`setup_ulp`].
90    ///
91    /// ## Errors
92    ///
93    /// Unsupported protocol version or cipher suite, or failure to set up
94    /// kTLS params on the socket.
95    pub fn new<K, E>(
96        socket: S,
97        secrets: K,
98        session: C,
99        buffer: Option<Buffer>,
100    ) -> Result<Self, ktls_core::Error>
101    where
102        ExtractedSecrets: TryFrom<K, Error = E>,
103        ktls_core::Error: From<E>,
104    {
105        let ExtractedSecrets {
106            tx: (seq_tx, secrets_tx),
107            rx: (seq_rx, secrets_rx),
108        } = ExtractedSecrets::try_from(secrets)?;
109
110        let tls_crypto_info_tx =
111            TlsCryptoInfoTx::new(session.protocol_version(), secrets_tx, seq_tx)?;
112
113        let tls_crypto_info_rx =
114            TlsCryptoInfoRx::new(session.protocol_version(), secrets_rx, seq_rx)?;
115
116        setup_tls_params(&socket, &tls_crypto_info_tx, &tls_crypto_info_rx)?;
117
118        Ok(Self {
119            inner: socket,
120            context: Context::new(session, buffer),
121        })
122    }
123
124    /// Returns a [`RawStreamMut`] which provides low-level access to the
125    /// inner socket.
126    ///
127    /// This requires a mutable reference to the [`Stream`] to ensure a
128    /// exclusive access to the inner socket.
129    ///
130    /// ## Notes
131    ///
132    /// * All buffered data **MUST** be properly consumed (See
133    ///   [`AccessRawStreamError::HasBufferedData`]).
134    ///
135    ///   The buffered data typically consists of:
136    ///
137    ///   - Early data received during handshake.
138    ///   - Application data received due to improper usage of
139    ///     [`RawStreamMut::handle_io_error`].
140    ///
141    /// * The caller **MAY** handle any [`io::Error`]s returned by direct I/O
142    ///   operations on the inner socket with [`RawStreamMut::handle_io_error`].
143    ///
144    /// * The caller **MUST NOT** *shutdown* the inner socket directly, which
145    ///   will lead to undefined behaviours.
146    ///
147    /// # Errors
148    ///
149    /// See [`AccessRawStreamError`].
150    pub fn as_mut_raw(&mut self) -> Result<RawStreamMut<'_, S, C>, AccessRawStreamError> {
151        if let Some(buffer) = self.context.buffer_mut().drain() {
152            return Err(AccessRawStreamError::HasBufferedData(buffer));
153        }
154
155        if self.context.state().is_closed() {
156            // Fully closed, just return error.
157            return Err(AccessRawStreamError::Closed);
158        }
159
160        Ok(RawStreamMut { this: self })
161    }
162
163    #[cfg(feature = "tls13-key-update")]
164    /// [`Context::refresh_traffic_keys`] against the inner socket.
165    ///
166    /// Use with caution, and do check [`Context::refresh_traffic_keys`] for
167    /// details.
168    ///
169    /// # Errors
170    ///
171    /// See [`Context::refresh_traffic_keys`].
172    pub fn refresh_traffic_keys(&mut self) -> Result<(), ktls_core::Error> {
173        self.context
174            .refresh_traffic_keys(&self.inner)
175    }
176}
177
178impl<S> Stream<S, DummyTlsSession>
179where
180    S: AsFd,
181{
182    #[inline]
183    /// Creates a new [`Stream`] with a [`DummyTlsSession`].
184    ///
185    /// This doesn't require the socket to have TLS ULP configured, we will
186    /// configure it here.
187    ///
188    /// See also [`Stream::new`].
189    ///
190    /// ## Errors
191    ///
192    /// See [`Stream::new`].
193    pub fn new_dummy(
194        socket: S,
195        secrets: ExtractedSecrets,
196        session: DummyTlsSession,
197        buffer: Option<Buffer>,
198    ) -> Result<Self, ktls_core::Error> {
199        setup_ulp(&socket)?;
200
201        Self::new(socket, secrets, session, buffer)
202    }
203}
204
/// Evaluates the given I/O expression in a loop until it succeeds.
///
/// A successful result is returned from the *enclosing function*. An error is
/// passed to `$this.context.handle_io_error(&$this.inner, ..)`. If that call
/// itself fails, the failure is propagated with `?`; otherwise the expression
/// is evaluated again.
macro_rules! handle_ret {
    ($this:expr, $($tt:tt)+) => {
        loop {
            let error = match $($tt)+ {
                Err(error) => error,
                done => return done,
            };

            $this.context.handle_io_error(&$this.inner, error)?;
        }
    };
}
217
218impl<S, C> Read for Stream<S, C>
219where
220    S: AsFd + Read,
221    C: TlsSession,
222{
223    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
224        handle_ret!(self, {
225            let read_from_buffer = self.context.buffer_mut().read(|data| {
226                crate::trace!("Read from buffer: remaining {} bytes", data.len());
227
228                let amt = buf.len().min(data.len());
229                buf[..amt].copy_from_slice(&data[..amt]);
230                amt
231            });
232
233            if let Some(read_from_buffer) = read_from_buffer {
234                return Ok(read_from_buffer.get());
235            }
236
237            if self.context.state().is_read_closed() {
238                crate::trace!("Read closed, returning EOF");
239
240                return Ok(0);
241            }
242
243            // Retry is OK, the implementation of `Read` requires no data will be
244            // read into the buffer when error occurs.
245            self.inner.read(buf)
246        })
247    }
248}
249
250macro_rules! impl_shutdown {
251    ($ty:ty) => {
252        impl<C> Stream<$ty, C>
253        where
254            C: TlsSession,
255        {
256            /// Shuts down both read and write sides of the TLS stream.
257            pub fn shutdown(&mut self) {
258                let is_write_closed = self.context.state().is_write_closed();
259
260                self.context.shutdown(&self.inner);
261
262                if !is_write_closed {
263                    let _ = self
264                        .inner
265                        .shutdown(std::net::Shutdown::Write);
266                }
267            }
268        }
269    };
270}
271
272impl_shutdown!(std::net::TcpStream);
273impl_shutdown!(std::os::unix::net::UnixStream);
274
275impl<S, C> Write for Stream<S, C>
276where
277    S: AsFd + Write,
278    C: TlsSession,
279{
280    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
281        handle_ret!(self, {
282            if self.context.state().is_write_closed() {
283                crate::trace!("Write closed, returning EOF");
284
285                return Ok(0);
286            }
287
288            // Retry is OK, the implementation of `Write` requires no data will
289            // be written when error occurs.
290            self.inner.write(buf)
291        })
292    }
293
294    fn flush(&mut self) -> io::Result<()> {
295        handle_ret!(self, {
296            if self.context.state().is_write_closed() {
297                crate::trace!("Write closed, skipping flush");
298
299                return Ok(());
300            }
301
302            self.inner.flush()
303        })
304    }
305}
306
/// Async counterpart of `handle_ret!`.
///
/// Evaluates the given poll expression in a loop. `Pending` and `Ready(Ok(_))`
/// are returned from the enclosing function. A `Ready(Err(_))` is passed to
/// `$this.context.handle_io_error(&*$this.inner, ..)`: a failure there is
/// propagated with `?`; otherwise the expression is polled again.
#[cfg(feature = "async-io-tokio")]
macro_rules! handle_ret_async {
    ($this:expr, $($tt:tt)+) => {
        loop {
            let error = match $($tt)+ {
                std::task::Poll::Ready(Err(error)) => error,
                settled => return settled,
            };

            $this.context.handle_io_error(&*$this.inner, error)?;
        }
    };
}
321
#[cfg(feature = "async-io-tokio")]
impl<S, C> tokio::io::AsyncRead for Stream<S, C>
where
    S: AsFd + tokio::io::AsyncRead,
    C: TlsSession,
{
    /// Polls for plaintext from the kTLS stream.
    ///
    /// Data buffered before kTLS took over is served first. Once that is
    /// exhausted and the read side is closed, EOF is signalled by completing
    /// without filling `buf`. Otherwise the inner socket is polled, with
    /// recoverable errors handled by the kTLS context and the poll retried.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        let mut projected = self.project();

        handle_ret_async!(projected, {
            let from_buffer = projected.context.buffer_mut().read(|pending| {
                let amt = buf.remaining().min(pending.len());

                crate::trace!(
                    "Read from buffer: remaining {} bytes, will read {} bytes",
                    pending.len(),
                    amt
                );

                buf.put_slice(&pending[..amt]);

                amt
            });

            match from_buffer {
                Some(_) => std::task::Poll::Ready(Ok(())),
                None if projected.context.state().is_read_closed() => {
                    crate::trace!("Read closed, returning EOF");

                    std::task::Poll::Ready(Ok(()))
                }
                // Retrying is sound: a failing `poll_read` must not have put
                // any data into `buf`.
                None => projected.inner.as_mut().poll_read(cx, buf),
            }
        })
    }
}
366
#[cfg(feature = "async-io-tokio")]
impl<S, C> tokio::io::AsyncWrite for Stream<S, C>
where
    S: AsFd + tokio::io::AsyncWrite,
    C: TlsSession,
{
    /// Polls a write through the kTLS socket.
    ///
    /// Completes with `Ok(0)` once the write side has been closed.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<io::Result<usize>> {
        let mut projected = self.project();

        handle_ret_async!(projected, {
            if projected.context.state().is_write_closed() {
                crate::trace!("Write closed, returning EOF");

                std::task::Poll::Ready(Ok(0))
            } else {
                projected.inner.as_mut().poll_write(cx, buf)
            }
        })
    }

    /// Polls a flush of the inner socket. This is a no-op once the write side
    /// is closed.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        let mut projected = self.project();

        handle_ret_async!(projected, {
            if projected.context.state().is_write_closed() {
                crate::trace!("Write closed, skipping flush");

                std::task::Poll::Ready(Ok(()))
            } else {
                projected.inner.as_mut().poll_flush(cx)
            }
        })
    }

    /// Runs the kTLS context's shutdown (notifying the peer that we are
    /// closing the write side). Then, if the write side was still open before
    /// this call, delegates to the inner socket's `poll_shutdown`.
    ///
    /// NOTE(review): if `Context::shutdown` marks the write side closed (the
    /// type-level docs suggest it does), a `Pending` from the inner
    /// `poll_shutdown` means the next call takes the early-return branch. The
    /// inner shutdown is then never driven to completion. Confirm whether a
    /// separate "inner shutdown in progress" flag is needed for socket types
    /// whose `poll_shutdown` can return `Pending`.
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        let projected = self.project();

        // Sample the state before `Context::shutdown` gets a chance to update it.
        let was_write_closed = projected.context.state().is_write_closed();

        projected.context.shutdown(&*projected.inner);

        if was_write_closed {
            return std::task::Poll::Ready(Ok(()));
        }

        projected.inner.poll_shutdown(cx)
    }
}
426
427/// See [`Stream::as_mut_raw`].
428pub struct RawStreamMut<'a, S: AsFd, C: TlsSession> {
429    this: &'a mut Stream<S, C>,
430}
431
432impl<S: AsFd, C: TlsSession> RawStreamMut<'_, S, C> {
433    /// Performs read operation on the inner socket, handles possible errors
434    /// with [`Context::handle_io_error`] and retries the operation if the
435    /// error is recoverable (see [`Context::handle_io_error`] for details).
436    ///
437    /// # Prerequisites
438    ///
439    /// The caller SHOULD NOT perform any *write* operations in `f`.
440    ///
441    /// # Errors
442    ///
443    /// - If the read side of the TLS stream is closed, this will return an EOF.
444    /// - Returns the original I/O error returned by `f` that is unrecoverable.
445    ///
446    ///   See also [`Context::handle_io_error`].
447    pub fn try_read_io<F, R>(&mut self, mut f: F) -> io::Result<R>
448    where
449        F: FnMut(&mut S, &mut Context<C>) -> io::Result<R>,
450    {
451        if self
452            .this
453            .context
454            .state()
455            .is_read_closed()
456        {
457            crate::trace!("Read closed, returning EOF");
458
459            return Err(io::Error::new(
460                io::ErrorKind::UnexpectedEof,
461                "TLS stream (read side) is closed",
462            ));
463        }
464
465        handle_ret!(self.this, f(&mut self.this.inner, &mut self.this.context));
466    }
467
468    /// Performs write operation on the inner socket, handles possible errors
469    /// with [`Context::handle_io_error`] and retries the operation if the
470    /// error is recoverable (see [`Context::handle_io_error`] for details).
471    ///
472    /// # Prerequisites
473    ///
474    /// The caller SHOULD NOT perform any *read* operations in `f`.
475    ///
476    /// # Errors
477    ///
478    /// - If the write side of the TLS stream is closed, this will return an
479    ///   EOF.
480    /// - Returns the original I/O error returned by `f` that is unrecoverable.
481    ///
482    ///   See also [`Context::handle_io_error`].
483    pub fn try_write_io<F, R>(&mut self, mut f: F) -> io::Result<R>
484    where
485        F: FnMut(&mut S, &mut Context<C>) -> io::Result<R>,
486    {
487        if self
488            .this
489            .context
490            .state()
491            .is_write_closed()
492        {
493            crate::trace!("Write closed, returning WriteZero");
494
495            return Err(io::Error::new(
496                io::ErrorKind::WriteZero,
497                "TLS stream (write side) is closed",
498            ));
499        }
500
501        handle_ret!(self.this, f(&mut self.this.inner, &mut self.this.context));
502    }
503
504    #[inline]
505    /// Since [`RawStreamMut`] provides direct access to the inner socket,
506    /// the caller **MUST** handle any possible I/O errors returned by I/O
507    /// operations on the inner socket with this method.
508    ///
509    /// See also [`Context::handle_io_error`].
510    ///
511    /// # Errors
512    ///
513    /// See [`Context::handle_io_error`].
514    pub fn handle_io_error(&mut self, err: io::Error) -> io::Result<()> {
515        self.this
516            .context
517            .handle_io_error(&self.this.inner, err)
518    }
519}
520
521impl<S: AsFd, C: TlsSession> AsFd for RawStreamMut<'_, S, C> {
522    #[inline]
523    fn as_fd(&self) -> BorrowedFd<'_> {
524        self.this.inner.as_fd()
525    }
526}
527
528impl<S: AsFd, C: TlsSession> AsRawFd for RawStreamMut<'_, S, C> {
529    #[inline]
530    fn as_raw_fd(&self) -> RawFd {
531        self.this.inner.as_fd().as_raw_fd()
532    }
533}
534
/// An error indicating that the inner socket cannot be accessed directly.
///
/// Returned by [`Stream::as_mut_raw`]. It implements [`std::fmt::Display`]
/// and [`std::error::Error`], so it composes with `?` and boxed error types.
#[non_exhaustive]
#[derive(Debug)]
pub enum AccessRawStreamError {
    /// The TLS connection is fully closed (both read and write sides).
    Closed,

    /// There's still buffered data that has not been retrieved yet.
    ///
    /// The data has been drained from the stream and is carried in this
    /// variant, so it is not lost. It typically consists of:
    ///
    /// - Early data received during handshake.
    /// - Application data received due to improper usage of
    ///   [`RawStreamMut::handle_io_error`].
    HasBufferedData(Vec<u8>),
}

impl std::fmt::Display for AccessRawStreamError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Closed => {
                f.write_str("the TLS stream is fully closed (both read and write sides)")
            }
            Self::HasBufferedData(data) => write!(
                f,
                "the TLS stream still has {} bytes of buffered data that must be retrieved first",
                data.len()
            ),
        }
    }
}

// No `source()`: neither variant wraps an underlying error.
impl std::error::Error for AccessRawStreamError {}