Skip to main content

miku_ktls/
ffi.rs

1use std::{io, mem::size_of_val, os::unix::io::RawFd};
2
3pub(crate) use ktls_sys::bindings;
4use rustls::{
5    internal::msgs::{enums::AlertLevel, message::Message},
6    AlertDescription, ConnectionTrafficSecrets, SupportedCipherSuite,
7};
8
9use crate::error::KtlsCompatibilityError;
10
11pub(crate) const TLS_1_2_VERSION_NUMBER: u16 = (((bindings::TLS_1_2_VERSION_MAJOR & 0xFF) as u16)
12    << 8)
13    | ((bindings::TLS_1_2_VERSION_MINOR & 0xFF) as u16);
14
15pub(crate) const TLS_1_3_VERSION_NUMBER: u16 = (((bindings::TLS_1_3_VERSION_MAJOR & 0xFF) as u16)
16    << 8)
17    | ((bindings::TLS_1_3_VERSION_MINOR & 0xFF) as u16);
18
19/// `setsockopt` level constant: TCP
20const SOL_TCP: libc::c_int = 6;
21
22/// `setsockopt` SOL_TCP name constant: "upper level protocol"
23const TCP_ULP: libc::c_int = 31;
24
25/// `setsockopt` level constant: TLS
26const SOL_TLS: libc::c_int = 282;
27
28/// `setsockopt` SOL_TLS level constant: transmit (write)
29const TLS_TX: libc::c_int = 1;
30
31/// `setsockopt` SOL_TLS level constant: receive (read)
32const TLS_RX: libc::c_int = 2;
33
34/// `setsockopt(fd, SOL_TCP, TCP_ULP, "tls", size_of("tls"))`
35pub(crate) fn setup_ulp(fd: RawFd) -> io::Result<()> {
36    unsafe {
37        if libc::setsockopt(
38            fd,
39            SOL_TCP,
40            TCP_ULP,
41            "tls".as_ptr() as *const libc::c_void,
42            3,
43        ) < 0
44        {
45            return Err(io::Error::last_os_error());
46        }
47    }
48
49    Ok(())
50}
51
52/// `setsockopt(fd, SOL_TLS, {TLS_TX or TLS_RX}, info, size_of(info))`
53pub(crate) fn setup_tls_info(
54    fd: RawFd,
55    dir: Direction,
56    info: CryptoInfo,
57) -> Result<(), crate::Error> {
58    unsafe {
59        if libc::setsockopt(fd, SOL_TLS, dir.as_c_int(), info.as_ptr(), info.size() as _) < 0 {
60            return Err(crate::Error::TlsCryptoInfoError(io::Error::last_os_error()));
61        }
62    }
63    Ok(())
64}
65
/// `SOL_TLS` direction.
///
/// Selects which half of the connection a crypto info is installed for when it is
/// passed to `setsockopt`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Direction {
    /// Transmit (write) side.
    Tx,

    /// Receive (read) side.
    Rx,
}
75
76impl Direction {
77    #[inline]
78    const fn as_c_int(self) -> libc::c_int {
79        match self {
80            Self::Tx => TLS_TX,
81            Self::Rx => TLS_RX,
82        }
83    }
84}
85
86#[allow(dead_code)]
87/// `SOL_TLS` crypto info.
88///
89/// This is a wrapper around the kernel structs.
90pub enum CryptoInfo {
91    AesGcm128(bindings::tls12_crypto_info_aes_gcm_128),
92    AesGcm256(bindings::tls12_crypto_info_aes_gcm_256),
93    AesCcm128(bindings::tls12_crypto_info_aes_ccm_128),
94    Chacha20Poly1305(bindings::tls12_crypto_info_chacha20_poly1305),
95    Sm4Gcm(bindings::tls12_crypto_info_sm4_gcm),
96    Sm4Ccm(bindings::tls12_crypto_info_sm4_ccm),
97}
98
99impl CryptoInfo {
100    /// Return the system struct as a raw pointer.
101    pub fn as_ptr(&self) -> *const libc::c_void {
102        match self {
103            Self::AesGcm128(info) => info as *const _ as *const libc::c_void,
104            Self::AesGcm256(info) => info as *const _ as *const libc::c_void,
105            Self::AesCcm128(info) => info as *const _ as *const libc::c_void,
106            Self::Chacha20Poly1305(info) => info as *const _ as *const libc::c_void,
107            Self::Sm4Gcm(info) => info as *const _ as *const libc::c_void,
108            Self::Sm4Ccm(info) => info as *const _ as *const libc::c_void,
109        }
110    }
111
112    #[inline]
113    /// Return the system struct size.
114    pub fn size(&self) -> usize {
115        match self {
116            Self::AesGcm128(info) => size_of_val(info),
117            Self::AesGcm256(info) => size_of_val(info),
118            Self::AesCcm128(info) => size_of_val(info),
119            Self::Chacha20Poly1305(info) => size_of_val(info),
120            Self::Sm4Gcm(info) => size_of_val(info),
121            Self::Sm4Ccm(info) => size_of_val(info),
122        }
123    }
124}
125
126impl CryptoInfo {
127    /// Try to convert rustls cipher suite and secrets into a `CryptoInfo`.
128    pub fn from_rustls(
129        cipher_suite: SupportedCipherSuite,
130        (seq, secrets): (u64, ConnectionTrafficSecrets),
131    ) -> Result<CryptoInfo, KtlsCompatibilityError> {
132        let version = match cipher_suite {
133            SupportedCipherSuite::Tls12(..) => TLS_1_2_VERSION_NUMBER,
134            SupportedCipherSuite::Tls13(..) => TLS_1_3_VERSION_NUMBER,
135        };
136
137        Ok(match secrets {
138            ConnectionTrafficSecrets::Aes128Gcm { key, iv } => {
139                // see https://github.com/rustls/rustls/issues/1833, between
140                // rustls 0.21 and 0.22, the extract_keys codepath was changed,
141                // so, for TLS 1.2, both GCM-128 and GCM-256 return the
142                // Aes128Gcm variant.
143
144                match key.as_ref().len() {
145                    16 => CryptoInfo::AesGcm128(bindings::tls12_crypto_info_aes_gcm_128 {
146                        info: bindings::tls_crypto_info {
147                            version,
148                            cipher_type: bindings::TLS_CIPHER_AES_GCM_128 as _,
149                        },
150                        iv: iv
151                            .as_ref()
152                            .get(4..)
153                            .expect("AES-GCM-128 iv is 8 bytes")
154                            .try_into()
155                            .expect("AES-GCM-128 iv is 8 bytes"),
156                        key: key
157                            .as_ref()
158                            .try_into()
159                            .expect("AES-GCM-128 key is 16 bytes"),
160                        salt: iv
161                            .as_ref()
162                            .get(..4)
163                            .expect("AES-GCM-128 salt is 4 bytes")
164                            .try_into()
165                            .expect("AES-GCM-128 salt is 4 bytes"),
166                        rec_seq: seq.to_be_bytes(),
167                    }),
168                    32 => CryptoInfo::AesGcm256(bindings::tls12_crypto_info_aes_gcm_256 {
169                        info: bindings::tls_crypto_info {
170                            version,
171                            cipher_type: bindings::TLS_CIPHER_AES_GCM_256 as _,
172                        },
173                        iv: iv
174                            .as_ref()
175                            .get(4..)
176                            .expect("AES-GCM-256 iv is 8 bytes")
177                            .try_into()
178                            .expect("AES-GCM-256 iv is 8 bytes"),
179                        key: key
180                            .as_ref()
181                            .try_into()
182                            .expect("AES-GCM-256 key is 32 bytes"),
183                        salt: iv
184                            .as_ref()
185                            .get(..4)
186                            .expect("AES-GCM-256 salt is 4 bytes")
187                            .try_into()
188                            .expect("AES-GCM-256 salt is 4 bytes"),
189                        rec_seq: seq.to_be_bytes(),
190                    }),
191                    _ => unreachable!("GCM key length is not 16 or 32"),
192                }
193            }
194            ConnectionTrafficSecrets::Aes256Gcm { key, iv } => {
195                CryptoInfo::AesGcm256(bindings::tls12_crypto_info_aes_gcm_256 {
196                    info: bindings::tls_crypto_info {
197                        version,
198                        cipher_type: bindings::TLS_CIPHER_AES_GCM_256 as _,
199                    },
200                    iv: iv
201                        .as_ref()
202                        .get(4..)
203                        .expect("AES-GCM-256 iv is 8 bytes")
204                        .try_into()
205                        .expect("AES-GCM-256 iv is 8 bytes"),
206                    key: key
207                        .as_ref()
208                        .try_into()
209                        .expect("AES-GCM-256 key is 32 bytes"),
210                    salt: iv
211                        .as_ref()
212                        .get(..4)
213                        .expect("AES-GCM-256 salt is 4 bytes")
214                        .try_into()
215                        .expect("AES-GCM-256 salt is 4 bytes"),
216                    rec_seq: seq.to_be_bytes(),
217                })
218            }
219            ConnectionTrafficSecrets::Chacha20Poly1305 { key, iv } => {
220                CryptoInfo::Chacha20Poly1305(bindings::tls12_crypto_info_chacha20_poly1305 {
221                    info: bindings::tls_crypto_info {
222                        version,
223                        cipher_type: bindings::TLS_CIPHER_CHACHA20_POLY1305 as _,
224                    },
225                    iv: iv
226                        .as_ref()
227                        .try_into()
228                        .expect("Chacha20-Poly1305 iv is 12 bytes"),
229                    key: key
230                        .as_ref()
231                        .try_into()
232                        .expect("Chacha20-Poly1305 key is 32 bytes"),
233                    salt: bindings::__IncompleteArrayField::new(),
234                    rec_seq: seq.to_be_bytes(),
235                })
236            }
237            _ => {
238                return Err(KtlsCompatibilityError::UnsupportedCipherSuite(cipher_suite));
239            }
240        })
241    }
242}
243
244const TLS_SET_RECORD_TYPE: libc::c_int = 1;
245const ALERT: u8 = 0x15;
246
247// Yes, really. cmsg components are aligned to [libc::c_long]
248#[cfg_attr(target_pointer_width = "32", repr(C, align(4)))]
249#[cfg_attr(target_pointer_width = "64", repr(C, align(8)))]
250struct Cmsg<const N: usize> {
251    hdr: libc::cmsghdr,
252    data: [u8; N],
253}
254
255impl<const N: usize> Cmsg<N> {
256    fn new(level: i32, typ: i32, data: [u8; N]) -> Self {
257        Self {
258            hdr: libc::cmsghdr {
259                // on Linux this is a usize, on macOS this is a u32
260                #[allow(clippy::unnecessary_cast)]
261                cmsg_len: (memoffset::offset_of!(Self, data) + N) as _,
262                cmsg_level: level,
263                cmsg_type: typ,
264            },
265            data,
266        }
267    }
268}
269
270pub(crate) fn send_close_notify(fd: RawFd) -> std::io::Result<()> {
271    let mut data = vec![];
272    Message::build_alert(AlertLevel::Warning, AlertDescription::CloseNotify)
273        .payload
274        .encode(&mut data);
275
276    let mut cmsg = Cmsg::new(SOL_TLS, TLS_SET_RECORD_TYPE, [ALERT]);
277
278    let msg = libc::msghdr {
279        msg_name: std::ptr::null_mut(),
280        msg_namelen: 0,
281        msg_iov: &mut libc::iovec {
282            iov_base: data.as_mut_ptr() as _,
283            iov_len: data.len(),
284        },
285        msg_iovlen: 1,
286        msg_control: &mut cmsg as *mut _ as *mut _,
287        msg_controllen: cmsg.hdr.cmsg_len,
288        msg_flags: 0,
289    };
290
291    let ret = unsafe { libc::sendmsg(fd, &msg, 0) };
292    if ret < 0 {
293        return Err(io::Error::last_os_error());
294    }
295    Ok(())
296}