ktls_core/
context.rs

1//! Kernel TLS connection context.
2
3use std::io;
4use std::os::fd::{AsFd, AsRawFd};
5
6use bitfield_struct::bitfield;
7
8use crate::error::{Error, InvalidMessage, PeerMisbehaved, Result};
9use crate::ffi::{recv_tls_record, send_tls_control_message};
10use crate::tls::{
11    AlertDescription, AlertLevel, ContentType, HandshakeType, KeyUpdateRequest, Peer,
12    ProtocolVersion, TlsSession,
13};
14use crate::utils::Buffer;
15
#[derive(Debug)]
/// The context for managing a kTLS connection.
///
/// Bundles the per-direction open/closed [`State`], the [`Buffer`] that
/// records received from the kernel are read into, and the TLS session `C`.
/// The session supplies the negotiated protocol version and local role, and
/// performs RX/TX secret updates and session-ticket handling.
pub struct Context<C: TlsSession> {
    // State of the current kTLS connection: whether the read and/or write
    // side has been closed.
    state: State,

    // Shared buffer: seeded from `Context::new` (early data or a
    // pre-allocated buffer) and reused to receive TLS records.
    buffer: Buffer,

    // TLS session
    session: C,
}
28
29impl<C: TlsSession> Context<C> {
30    /// Creates a new kTLS context with the given TLS session and optional
31    /// buffer (can be TLS early data received from peer during handshake, or
32    /// pre-allocated buffer).
33    pub fn new(session: C, buffer: Option<Buffer>) -> Self {
34        Self {
35            state: State::new(),
36            buffer: buffer.unwrap_or_default(),
37            session,
38        }
39    }
40
41    /// Returns the current kTLS connection state.
42    pub const fn state(&self) -> &State {
43        &self.state
44    }
45
46    /// Returns a reference to the buffer.
47    pub const fn buffer(&self) -> &Buffer {
48        &self.buffer
49    }
50
51    /// Returns a mutable reference to the buffer.
52    pub const fn buffer_mut(&mut self) -> &mut Buffer {
53        &mut self.buffer
54    }
55
56    /// Handles [`io::Error`]s from I/O operations on kTLS-configured sockets.
57    ///
58    /// # Overview
59    ///
60    /// When a socket is configured with kTLS, it can be used like a normal
61    /// socket for data transmission - the kernel transparently handles
62    /// encryption and decryption. However, TLS control messages (e.g., TLS
63    /// alerts) from peers cannot be processed automatically by the kernel,
64    /// which returns `EIO` to notify userspace.
65    ///
66    /// This method helps handle such errors appropriately:
67    ///
68    /// - **`EIO`**: Attempts to process any received TLS control messages.
69    ///   Returns `Ok(())` on success, allowing the caller to retry the
70    ///   operation.
71    /// - **`BrokenPipe`**: Marks the stream as closed.
72    /// - Other errors: Aborts the connection with an `internal_error` alert and
73    ///   returns the original error.
74    ///
75    /// # Errors
76    ///
77    /// Returns the original [`io::Error`] if it cannot be recovered from.
78    pub fn handle_io_error<S: AsFd>(&mut self, socket: &S, err: io::Error) -> io::Result<()> {
79        if err.raw_os_error() == Some(libc::EIO) {
80            crate::trace!("Received EIO, handling TLS control message");
81
82            self.handle_tls_control_message(socket)?;
83
84            return Ok(());
85        }
86
87        if err.kind() == io::ErrorKind::BrokenPipe {
88            crate::trace!("The underlying stream is closed (BrokenPipe)");
89
90            // No need to send alert, the peer has closed the
91            // connection abruptly.
92        } else {
93            self.send_tls_alert(socket, AlertLevel::Fatal, AlertDescription::InternalError);
94        }
95
96        self.state.set_is_read_closed(true);
97        self.state.set_is_write_closed(true);
98
99        Err(err)
100    }
101
102    #[allow(clippy::too_many_lines)]
103    /// Handles TLS control messages received by kernel.
104    ///
105    /// The caller SHOULD first check if the raw os error returned were
106    /// `EIO`, which indicates that there is a TLS control message available.
107    ///
108    /// But in fact, this method can be called even if there's no TLS control
109    /// message (not recommended to do so).
110    fn handle_tls_control_message<S: AsFd>(&mut self, socket: &S) -> Result<()> {
111        match recv_tls_record(socket.as_fd().as_raw_fd(), &mut self.buffer) {
112            Ok(ContentType::Handshake) => {
113                return self.handle_tls_control_message_handshake(socket);
114            }
115            Ok(ContentType::Alert) => {
116                if let &[level, desc] = self.buffer.unfilled_initialized() {
117                    return self.handle_tls_control_message_alert(
118                        socket,
119                        AlertLevel::from_int(level),
120                        AlertDescription::from_int(desc),
121                    );
122                }
123
124                // The peer sent an invalid alert. We send back an error
125                // and close the connection.
126
127                crate::error!(
128                    "Invalid alert message received: {:?}, {:?}",
129                    self.buffer.unfilled_initialized(),
130                    self.buffer
131                );
132
133                return self.abort(
134                    socket,
135                    InvalidMessage::MessageTooLarge,
136                    InvalidMessage::MessageTooLarge.description(),
137                );
138            }
139            Ok(ContentType::ChangeCipherSpec) => {
140                // ChangeCipherSpec should only be sent under the following conditions:
141                //
142                // * TLS 1.2: during a handshake or a rehandshake
143                // * TLS 1.3: during a handshake
144                //
145                // We don't have to worry about handling messages during a handshake
146                // and rustls does not support TLS 1.2 rehandshakes so we just emit
147                // an error here and abort the connection.
148
149                crate::warn!("Received unexpected ChangeCipherSpec message");
150
151                return self.abort(
152                    socket,
153                    PeerMisbehaved::IllegalMiddleboxChangeCipherSpec,
154                    PeerMisbehaved::IllegalMiddleboxChangeCipherSpec.description(),
155                );
156            }
157            Ok(ContentType::ApplicationData) => {
158                // This shouldn't happen in normal operation.
159
160                crate::warn!(
161                    "Received {} bytes of application data, unexpected usage",
162                    self.buffer.unfilled_initialized().len()
163                );
164
165                self.buffer.set_filled_all();
166            }
167            Ok(_content_type) => {
168                crate::error!(
169                    "Received unexpected TLS control message: content_type={_content_type:?}",
170                );
171
172                return self.abort(
173                    socket,
174                    InvalidMessage::InvalidContentType,
175                    InvalidMessage::InvalidContentType.description(),
176                );
177            }
178            Err(error) => {
179                crate::error!("Failed to receive TLS control message: {error}");
180
181                return self.abort(
182                    socket,
183                    Error::General(error),
184                    AlertDescription::InternalError,
185                );
186            }
187        }
188
189        Ok(())
190    }
191
192    #[allow(clippy::too_many_lines)]
193    /// Handles a TLS alert received from the peer.
194    fn handle_tls_control_message_handshake<S: AsFd>(&mut self, socket: &S) -> Result<()> {
195        let mut messages =
196            HandshakeMessagesIter::new(self.buffer.unfilled_initialized()).enumerate();
197
198        while let Some((idx, payload)) = messages.next() {
199            let Ok((handshake_type, payload)) = payload else {
200                return self.abort(
201                    socket,
202                    InvalidMessage::MessageTooShort,
203                    InvalidMessage::MessageTooShort.description(),
204                );
205            };
206
207            match handshake_type {
208                HandshakeType::KeyUpdate
209                    if self.session.protocol_version() == ProtocolVersion::TLSv1_3 =>
210                {
211                    if idx != 0 || messages.next().is_some() {
212                        crate::error!(
213                            "RFC 8446, section 5.1: Handshake messages MUST NOT span key changes."
214                        );
215
216                        return self.abort(
217                            socket,
218                            PeerMisbehaved::KeyEpochWithPendingFragment,
219                            PeerMisbehaved::KeyEpochWithPendingFragment.description(),
220                        );
221                    }
222
223                    let &[payload] = payload else {
224                        crate::error!(
225                            "Received invalid KeyUpdate message, expected 1 byte payload, got: \
226                             {:?}",
227                            payload
228                        );
229
230                        return self.abort(
231                            socket,
232                            InvalidMessage::InvalidKeyUpdate,
233                            InvalidMessage::InvalidKeyUpdate.description(),
234                        );
235                    };
236
237                    let key_update_request = KeyUpdateRequest::from_int(payload);
238
239                    if let Err(error) = self
240                        .session
241                        .update_rx_secret()
242                        .and_then(|secret| secret.set(socket))
243                    {
244                        return self.abort(socket, error, AlertDescription::InternalError);
245                    }
246
247                    match key_update_request {
248                        KeyUpdateRequest::UpdateNotRequested => {}
249                        KeyUpdateRequest::UpdateRequested => {
250                            // Notify the peer that we are updating our TX secret as well.
251                            if let Err(error) = send_tls_control_message(
252                                socket.as_fd().as_raw_fd(),
253                                ContentType::Handshake,
254                                &mut [
255                                    HandshakeType::KeyUpdate.to_int(), // typ
256                                    0,
257                                    0,
258                                    1, // length
259                                    KeyUpdateRequest::UpdateNotRequested.to_int(),
260                                ],
261                            )
262                            .map_err(Error::KeyUpdateFailed)
263                            {
264                                // Failed to notify the peer, abort the connection.
265                                crate::error!("Failed to send KeyUpdate message: {error}");
266
267                                return self.abort(socket, error, AlertDescription::InternalError);
268                            }
269
270                            if let Err(error) = self
271                                .session
272                                .update_tx_secret()
273                                .and_then(|secret| secret.set(socket))
274                            {
275                                crate::error!("Failed to update TX secret: {error}");
276
277                                return self.abort(socket, error, AlertDescription::InternalError);
278                            }
279                        }
280                        KeyUpdateRequest::Unknown(_payload) => {
281                            crate::warn!(
282                                "Received KeyUpdate message with unknown request value: {_payload}"
283                            );
284
285                            return self.abort(
286                                socket,
287                                InvalidMessage::InvalidKeyUpdate,
288                                InvalidMessage::InvalidKeyUpdate.description(),
289                            );
290                        }
291                    }
292                }
293                HandshakeType::NewSessionTicket
294                    if self.session.protocol_version() == ProtocolVersion::TLSv1_3 =>
295                {
296                    if self.session.peer() != Peer::Client {
297                        crate::warn!("TLS 1.2 peer sent a TLS 1.3 NewSessionTicket message");
298
299                        return self.abort(
300                            socket,
301                            InvalidMessage::UnexpectedMessage(
302                                "TLS 1.2 peer sent a TLS 1.3 NewSessionTicket message",
303                            ),
304                            AlertDescription::UnexpectedMessage,
305                        );
306                    }
307
308                    if let Err(error) = self
309                        .session
310                        .handle_new_session_ticket(payload)
311                    {
312                        return self.abort(socket, error, AlertDescription::InternalError);
313                    }
314                }
315                _ if self.session.protocol_version() == ProtocolVersion::TLSv1_3 => {
316                    crate::error!(
317                        "Unexpected handshake message for a TLS 1.3 connection: \
318                         typ={handshake_type:?}",
319                    );
320
321                    return self.abort(
322                        socket,
323                        InvalidMessage::UnexpectedMessage(
324                            "expected KeyUpdate or NewSessionTicket message",
325                        ),
326                        AlertDescription::UnexpectedMessage,
327                    );
328                }
329                _ => {
330                    crate::error!(
331                        "Unexpected handshake message: ver={:?}, typ={handshake_type:?}",
332                        self.session.protocol_version()
333                    );
334
335                    return self.abort(
336                        socket,
337                        InvalidMessage::UnexpectedMessage(
338                            "handshake messages are not expected on TLS 1.2 connections",
339                        ),
340                        AlertDescription::UnexpectedMessage,
341                    );
342                }
343            }
344        }
345
346        Ok(())
347    }
348
349    /// Handles a TLS alert received from the peer.
350    fn handle_tls_control_message_alert<S: AsFd>(
351        &mut self,
352        socket: &S,
353        level: AlertLevel,
354        desc: AlertDescription,
355    ) -> Result<()> {
356        match desc {
357            AlertDescription::CloseNotify
358                if self.session.protocol_version() == ProtocolVersion::TLSv1_2 =>
359            {
360                // RFC 5246, section 7.2.1: Unless some other fatal alert has been transmitted,
361                // each party is required to send a close_notify alert before closing the write
362                // side of the connection.  The other party MUST respond with a close_notify
363                // alert of its own and close down the connection immediately, discarding any
364                // pending writes.
365                crate::trace!("Received `close_notify` alert, should shutdown the TLS stream");
366
367                self.shutdown(socket);
368            }
369            AlertDescription::CloseNotify => {
370                // RFC 8446, section 6.1: Each party MUST send a "close_notify" alert before
371                // closing its write side of the connection, unless it has already sent some
372                // error alert. This does not have any effect on its read side of the
373                // connection. Note that this is a change from versions of TLS prior to TLS 1.3
374                // in which implementations were required to react to a "close_notify" by
375                // discarding pending writes and sending an immediate "close_notify" alert of
376                // their own. That previous requirement could cause truncation in the read
377                // side. Both parties need not wait to receive a "close_notify" alert before
378                // closing their read side of the connection, though doing so would introduce
379                // the possibility of truncation.
380
381                crate::trace!(
382                    "Received `close_notify` alert, should shutdown the read side of TLS stream"
383                );
384
385                self.state.set_is_read_closed(true);
386            }
387            _ if self.session.protocol_version() == ProtocolVersion::TLSv1_2
388                && level == AlertLevel::Warning =>
389            {
390                // RFC 5246, section 7.2.2: If an alert with a level of warning
391                // is sent and received, generally the connection can continue
392                // normally.
393
394                crate::warn!("Received non fatal alert, level={level:?}, desc: {desc:?}");
395            }
396            _ => {
397                // All other alerts are treated as fatal and result in us immediately shutting
398                // down the connection and emitting an error.
399
400                crate::error!("Received fatal alert, desc: {desc:?}");
401
402                self.state.set_is_read_closed(true);
403                self.state.set_is_write_closed(true);
404
405                return Err(Error::AlertReceived(desc));
406            }
407        }
408
409        Ok(())
410    }
411
412    /// Closes the read side of the kTLS connection and sends a `close_notify`
413    /// alert to the peer.
414    pub fn shutdown<S: AsFd>(&mut self, socket: &S) {
415        crate::trace!("Shutting down the TLS stream with `close_notify` alert...");
416
417        self.send_tls_alert(socket, AlertLevel::Warning, AlertDescription::CloseNotify);
418
419        if self.session.protocol_version() == ProtocolVersion::TLSv1_2 {
420            // See RFC 5246, section 7.2.1
421            self.state.set_is_read_closed(true);
422        }
423
424        self.state.set_is_write_closed(true);
425    }
426
427    /// Aborts the kTLS connection and sends a fatal alert to the peer.
428    fn abort<T, S, E, D>(&mut self, socket: &S, error: E, description: D) -> Result<T>
429    where
430        S: AsFd,
431        E: Into<Error>,
432        D: Into<AlertDescription>,
433    {
434        crate::trace!("Aborting the TLS stream with fatal alert...");
435
436        self.send_tls_alert(socket, AlertLevel::Fatal, description.into());
437
438        self.state.set_is_read_closed(true);
439        self.state.set_is_write_closed(true);
440
441        Err(error.into())
442    }
443
444    /// Sends a TLS alert to the peer.
445    fn send_tls_alert<S: AsFd>(
446        &mut self,
447        socket: &S,
448        level: AlertLevel,
449        description: AlertDescription,
450    ) {
451        if !self.state.is_write_closed() {
452            let _ = send_tls_control_message(
453                socket.as_fd().as_raw_fd(),
454                ContentType::Alert,
455                &mut [level.to_int(), description.to_int()],
456            )
457            .inspect_err(|_e| {
458                crate::trace!("Failed to send alert: {_e}");
459            });
460        }
461    }
462}
463
#[bitfield(u8)]
/// State of the kTLS connection.
///
/// Packed into a single byte: one flag per direction followed by reserved
/// bits. Within this module the flags are only ever set (to `true`), never
/// cleared.
pub struct State {
    /// Whether the read side is closed.
    pub is_read_closed: bool,

    /// Whether the write side is closed.
    pub is_write_closed: bool,

    // Unused padding bits filling out the `u8`.
    #[bits(6)]
    _reserved: u8,
}
476
477impl State {
478    /// Returns whether the connection is fully closed (both read and write
479    /// sides).
480    #[must_use]
481    pub const fn is_closed(&self) -> bool {
482        self.is_read_closed() && self.is_write_closed()
483    }
484}
485
/// Iterator over the handshake messages contained in one handshake record.
///
/// Each message is framed as a 1-byte handshake type, a 3-byte big-endian
/// payload length, and that many payload bytes.
struct HandshakeMessagesIter<'a> {
    /// `Ok(Some(rest))` while unparsed bytes remain, `Ok(None)` once every
    /// message has been consumed, and `Err(())` after a framing error.
    inner: Result<Option<&'a [u8]>, ()>,
}
489
490impl<'a> HandshakeMessagesIter<'a> {
491    const fn new(payloads: &'a [u8]) -> Self {
492        Self {
493            inner: Ok(Some(payloads)),
494        }
495    }
496}
497
498impl<'a> Iterator for HandshakeMessagesIter<'a> {
499    type Item = Result<(HandshakeType, &'a [u8]), ()>;
500
501    fn next(&mut self) -> Option<Self::Item> {
502        match self.inner {
503            Ok(None) => None,
504            Ok(Some(&[typ, a, b, c, ref rest @ ..])) => {
505                let handshake_type = HandshakeType::from_int(typ);
506                let payload_length = u32::from_be_bytes([0, a, b, c]) as usize;
507
508                let Some((payload, rest)) = rest.split_at_checked(payload_length) else {
509                    crate::error!(
510                        "Received truncated handshake message payload, expected: \
511                         {payload_length}, actual: {}",
512                        rest.len()
513                    );
514
515                    self.inner = Err(());
516
517                    return Some(Err(()));
518                };
519
520                if rest.is_empty() {
521                    self.inner = Ok(None);
522                } else {
523                    self.inner = Ok(Some(rest));
524                }
525
526                Some(Ok((handshake_type, payload)))
527            }
528            Ok(Some(_truncated)) => {
529                crate::error!("Received truncated handshake message payload: {_truncated:?}");
530
531                self.inner = Err(());
532
533                Some(Err(()))
534            }
535            Err(()) => Some(Err(())),
536        }
537    }
538}