ktls_core/
setup.rs

//! Transport Layer Security (TLS) is an Upper Layer Protocol (ULP) that runs
//! over TCP. TLS provides end-to-end data integrity and confidentiality.
//!
//! Once the TCP connection is established, the TLS ULP must be set on the
//! socket; this is what allows TLS socket options to be set and read.
//!
//! This module provides the [`setup_ulp`] function, which sets the ULP (Upper
//! Layer Protocol) to TLS for a TCP socket. The error it returns can also be
//! inspected to determine whether the running kernel supports kTLS.
//!
//! After the TLS handshake is completed, we have all the parameters required to
//! move the data path into the kernel. There is a separate socket option for
//! moving each of the transmit and receive directions into the kernel.
//!
//! This module provides the low-level [`setup_tls_params`] function, which sets
//! the kernel TLS parameters on the TCP socket, allowing the kernel to handle
//! encryption and decryption of the TLS data.
18
19use std::marker::PhantomData;
20use std::os::fd::{AsFd, AsRawFd};
21use std::{fmt, io, mem};
22
23use nix::sys::socket::{setsockopt, sockopt};
24use zeroize::Zeroize;
25
26use crate::error::{Error, Result};
27use crate::tls::{AeadKey, ConnectionTrafficSecrets, ProtocolVersion};
28
29/// Sets the TLS Upper Layer Protocol (ULP).
30///
31/// This should be called before performing any I/O operations on the
32/// socket.
33///
34/// # Errors
35///
36/// The caller may check if the error is due to the running kernel not
37/// supporting kTLS (e.g., kernel module `tls` not being enabled or the
38/// kernel version being too old) with [`Error::is_ktls_unsupported`].
39pub fn setup_ulp<S: AsFd>(socket: &S) -> Result<()> {
40    setsockopt(socket, sockopt::TcpUlp::default(), b"tls")
41        .map_err(io::Error::from)
42        .map_err(Error::Ulp)
43}
44
45/// Sets the kTLS parameters on the socket after the TLS handshake is completed.
46///
47/// Notes that most kernels do not support setting up TLS crypto materials
48/// twice more times.
49///
50/// ## Errors
51///
52/// * Invalid crypto materials.
53/// * Syscall error.
54pub fn setup_tls_params<S: AsFd>(
55    socket: &S,
56    tx: &TlsCryptoInfoTx,
57    rx: &TlsCryptoInfoRx,
58) -> Result<()> {
59    tx.set(socket)?;
60    rx.set(socket)?;
61
62    Ok(())
63}
64
65/// A wrapper around the `libc::tls12_crypto_info_*` structs, use with setting
66/// up the kTLS r/w parameters on the TCP socket.
67///
68/// This is originated from the `nix` crate, which currently does not support
69/// `AES-128-CCM`, `SM4-*` or `ARIA-*`, so we implement our own version here.
70pub struct TlsCryptoInfo<D> {
71    inner: TlsCryptoInfoImpl,
72    _direction: PhantomData<D>,
73}
74
75impl fmt::Debug for TlsCryptoInfoImpl {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        f.debug_struct("TlsCryptoInfo").finish()
78    }
79}
80
81/// Type alias of [`TlsCryptoInfo`], for transmit direction.
82pub type TlsCryptoInfoTx = TlsCryptoInfo<Tx>;
83
84/// Type alias of [`TlsCryptoInfo`], for receive direction.
85pub type TlsCryptoInfoRx = TlsCryptoInfo<Rx>;
86
/// Type-level marker selecting the transmit ("tx") direction.
///
/// It carries no data. `#[non_exhaustive]` prevents it from being constructed
/// outside this crate.
#[non_exhaustive]
pub struct Tx;

/// Type-level marker selecting the receive ("rx") direction.
///
/// It carries no data. `#[non_exhaustive]` prevents it from being constructed
/// outside this crate.
#[non_exhaustive]
pub struct Rx;
94
95impl<D> TlsCryptoInfo<D> {
96    #[inline]
97    /// Creates a new [`TlsCryptoInfo`] from the given protocol version and
98    /// connection traffic secrets.
99    ///
100    /// # Errors
101    ///
102    /// Invalid protocol version (only TLS 1.2 and TLS 1.3 are supported).
103    pub fn new(
104        protocol_version: ProtocolVersion,
105        secrets: ConnectionTrafficSecrets,
106        seq: u64,
107    ) -> Result<Self> {
108        TlsCryptoInfoImpl::new(protocol_version, secrets, seq).map(|inner| Self {
109            inner,
110            _direction: PhantomData,
111        })
112    }
113}
114
115impl TlsCryptoInfoTx {
116    /// Sets the kTLS parameters on the given file descriptor for the transmit
117    /// direction.
118    ///
119    /// This is a low-level function, usually you don't need to call it
120    /// directly.
121    ///
122    /// # Errors
123    ///
124    /// Errors may include invalid crypto materials or syscall errors.
125    pub fn set<S: AsFd>(&self, socket: &S) -> Result<()> {
126        self.inner
127            .set(socket, libc::TLS_TX)
128            .map_err(Error::CryptoMaterial)
129    }
130}
131
132impl TlsCryptoInfoRx {
133    /// Sets the kTLS parameters on the given file descriptor for the receive
134    /// direction.
135    ///
136    /// This is a low-level function, usually you don't need to call it
137    /// directly.
138    ///
139    /// # Errors
140    ///
141    /// Errors may include invalid crypto materials or syscall errors.
142    pub fn set<S: AsFd>(&self, socket: &S) -> Result<()> {
143        self.inner
144            .set(socket, libc::TLS_RX)
145            .map_err(Error::CryptoMaterial)
146    }
147}
148
149#[repr(C)]
150enum TlsCryptoInfoImpl {
151    AesGcm128(libc::tls12_crypto_info_aes_gcm_128),
152    AesGcm256(libc::tls12_crypto_info_aes_gcm_256),
153    AesCcm128(libc::tls12_crypto_info_aes_ccm_128),
154    Chacha20Poly1305(libc::tls12_crypto_info_chacha20_poly1305),
155    Sm4Gcm(libc::tls12_crypto_info_sm4_gcm),
156    Sm4Ccm(libc::tls12_crypto_info_sm4_ccm),
157    Aria128Gcm(libc::tls12_crypto_info_aria_gcm_128),
158    Aria256Gcm(libc::tls12_crypto_info_aria_gcm_256),
159}
160
161impl TlsCryptoInfoImpl {
162    #[allow(unused_qualifications)]
163    #[allow(clippy::cast_possible_truncation)] // Since Rust 2021 doesn't have `size_of_val` included in prelude.
164    #[inline]
165    /// Sets the kTLS parameters on the given file descriptor.
166    fn set<S: AsFd>(&self, socket: &S, direction: libc::c_int) -> io::Result<()> {
167        let (ffi_ptr, ffi_len) = match self {
168            Self::AesGcm128(crypto_info) => (
169                <*const _>::cast(crypto_info),
170                mem::size_of_val(crypto_info) as libc::socklen_t,
171            ),
172            Self::AesGcm256(crypto_info) => (
173                <*const _>::cast(crypto_info),
174                mem::size_of_val(crypto_info) as libc::socklen_t,
175            ),
176            Self::AesCcm128(crypto_info) => (
177                <*const _>::cast(crypto_info),
178                mem::size_of_val(crypto_info) as libc::socklen_t,
179            ),
180            Self::Chacha20Poly1305(crypto_info) => (
181                <*const _>::cast(crypto_info),
182                mem::size_of_val(crypto_info) as libc::socklen_t,
183            ),
184            Self::Sm4Gcm(crypto_info) => (
185                <*const _>::cast(crypto_info),
186                mem::size_of_val(crypto_info) as libc::socklen_t,
187            ),
188            Self::Sm4Ccm(crypto_info) => (
189                <*const _>::cast(crypto_info),
190                mem::size_of_val(crypto_info) as libc::socklen_t,
191            ),
192            Self::Aria128Gcm(crypto_info) => (
193                <*const _>::cast(crypto_info),
194                mem::size_of_val(crypto_info) as libc::socklen_t,
195            ),
196            Self::Aria256Gcm(crypto_info) => (
197                <*const _>::cast(crypto_info),
198                mem::size_of_val(crypto_info) as libc::socklen_t,
199            ),
200        };
201
202        #[allow(unsafe_code)]
203        // SAFETY: syscall
204        let ret = unsafe {
205            libc::setsockopt(
206                socket.as_fd().as_raw_fd(),
207                libc::SOL_TLS,
208                direction,
209                ffi_ptr,
210                ffi_len,
211            )
212        };
213
214        if ret < 0 {
215            return Err(io::Error::last_os_error());
216        }
217
218        Ok(())
219    }
220
221    #[allow(clippy::too_many_lines)]
222    #[allow(clippy::needless_pass_by_value)]
223    /// Extract the [`TlsCryptoInfo`] from the given
224    /// [`ProtocolVersion`] and [`ConnectionTrafficSecrets`].
225    fn new(
226        protocol_version: ProtocolVersion,
227        secrets: ConnectionTrafficSecrets,
228        seq: u64,
229    ) -> Result<Self> {
230        let version = match protocol_version {
231            ProtocolVersion::TLSv1_2 => libc::TLS_1_2_VERSION,
232            ProtocolVersion::TLSv1_3 => libc::TLS_1_3_VERSION,
233            r => return Err(Error::UnsupportedProtocolVersion(r)),
234        };
235
236        let this = match secrets {
237            ConnectionTrafficSecrets::Aes128Gcm {
238                key: AeadKey(key),
239                iv,
240                salt,
241            } => Self::AesGcm128(libc::tls12_crypto_info_aes_gcm_128 {
242                info: libc::tls_crypto_info {
243                    version,
244                    cipher_type: libc::TLS_CIPHER_AES_GCM_128,
245                },
246                iv,
247                key,
248                salt,
249                rec_seq: seq.to_be_bytes(),
250            }),
251            ConnectionTrafficSecrets::Aes256Gcm {
252                key: AeadKey(key),
253                iv,
254                salt,
255            } => Self::AesGcm256(libc::tls12_crypto_info_aes_gcm_256 {
256                info: libc::tls_crypto_info {
257                    version,
258                    cipher_type: libc::TLS_CIPHER_AES_GCM_256,
259                },
260                iv,
261                key,
262                salt,
263                rec_seq: seq.to_be_bytes(),
264            }),
265            ConnectionTrafficSecrets::Chacha20Poly1305 {
266                key: AeadKey(key),
267                iv,
268                salt,
269            } => Self::Chacha20Poly1305(libc::tls12_crypto_info_chacha20_poly1305 {
270                info: libc::tls_crypto_info {
271                    version,
272                    cipher_type: libc::TLS_CIPHER_CHACHA20_POLY1305,
273                },
274                iv,
275                key,
276                salt,
277                rec_seq: seq.to_be_bytes(),
278            }),
279            ConnectionTrafficSecrets::Aes128Ccm {
280                key: AeadKey(key),
281                iv,
282                salt,
283            } => Self::AesCcm128(libc::tls12_crypto_info_aes_ccm_128 {
284                info: libc::tls_crypto_info {
285                    version,
286                    cipher_type: libc::TLS_CIPHER_AES_CCM_128,
287                },
288                iv,
289                key,
290                salt,
291                rec_seq: seq.to_be_bytes(),
292            }),
293            ConnectionTrafficSecrets::Sm4Gcm {
294                key: AeadKey(key),
295                iv,
296                salt,
297            } => Self::Sm4Gcm(libc::tls12_crypto_info_sm4_gcm {
298                info: libc::tls_crypto_info {
299                    version,
300                    cipher_type: libc::TLS_CIPHER_SM4_GCM,
301                },
302                iv,
303                key,
304                salt,
305                rec_seq: seq.to_be_bytes(),
306            }),
307            ConnectionTrafficSecrets::Sm4Ccm {
308                key: AeadKey(key),
309                iv,
310                salt,
311            } => Self::Sm4Ccm(libc::tls12_crypto_info_sm4_ccm {
312                info: libc::tls_crypto_info {
313                    version,
314                    cipher_type: libc::TLS_CIPHER_SM4_CCM,
315                },
316                iv,
317                key,
318                salt,
319                rec_seq: seq.to_be_bytes(),
320            }),
321            ConnectionTrafficSecrets::Aria128Gcm {
322                key: AeadKey(key),
323                iv,
324                salt,
325            } => Self::Aria128Gcm(libc::tls12_crypto_info_aria_gcm_128 {
326                info: libc::tls_crypto_info {
327                    version,
328                    cipher_type: libc::TLS_CIPHER_ARIA_GCM_128,
329                },
330                iv,
331                key,
332                salt,
333                rec_seq: seq.to_be_bytes(),
334            }),
335            ConnectionTrafficSecrets::Aria256Gcm {
336                key: AeadKey(key),
337                iv,
338                salt,
339            } => Self::Aria256Gcm(libc::tls12_crypto_info_aria_gcm_256 {
340                info: libc::tls_crypto_info {
341                    version,
342                    cipher_type: libc::TLS_CIPHER_ARIA_GCM_256,
343                },
344                iv,
345                key,
346                salt,
347                rec_seq: seq.to_be_bytes(),
348            }),
349        };
350
351        Ok(this)
352    }
353}
354
355impl Drop for TlsCryptoInfoImpl {
356    fn drop(&mut self) {
357        #[allow(clippy::match_same_arms)]
358        match self {
359            Self::AesGcm128(libc::tls12_crypto_info_aes_gcm_128 { key, .. }) => {
360                key.zeroize();
361            }
362            Self::AesGcm256(libc::tls12_crypto_info_aes_gcm_256 { key, .. }) => {
363                key.zeroize();
364            }
365            Self::AesCcm128(libc::tls12_crypto_info_aes_ccm_128 { key, .. }) => {
366                key.zeroize();
367            }
368            Self::Chacha20Poly1305(libc::tls12_crypto_info_chacha20_poly1305 { key, .. }) => {
369                key.zeroize();
370            }
371            Self::Sm4Gcm(libc::tls12_crypto_info_sm4_gcm { key, .. }) => {
372                key.zeroize();
373            }
374            Self::Sm4Ccm(libc::tls12_crypto_info_sm4_ccm { key, .. }) => {
375                key.zeroize();
376            }
377            Self::Aria128Gcm(libc::tls12_crypto_info_aria_gcm_128 { key, .. }) => {
378                key.zeroize();
379            }
380            Self::Aria256Gcm(libc::tls12_crypto_info_aria_gcm_256 { key, .. }) => {
381                key.zeroize();
382            }
383        }
384    }
385}