1use std::{io, mem::size_of_val, os::unix::io::RawFd};
2
3pub(crate) use ktls_sys::bindings;
4use rustls::{
5 internal::msgs::{enums::AlertLevel, message::Message},
6 AlertDescription, ConnectionTrafficSecrets, SupportedCipherSuite,
7};
8
9use crate::error::KtlsCompatibilityError;
10
11pub(crate) const TLS_1_2_VERSION_NUMBER: u16 = (((bindings::TLS_1_2_VERSION_MAJOR & 0xFF) as u16)
12 << 8)
13 | ((bindings::TLS_1_2_VERSION_MINOR & 0xFF) as u16);
14
15pub(crate) const TLS_1_3_VERSION_NUMBER: u16 = (((bindings::TLS_1_3_VERSION_MAJOR & 0xFF) as u16)
16 << 8)
17 | ((bindings::TLS_1_3_VERSION_MINOR & 0xFF) as u16);
18
19const SOL_TCP: libc::c_int = 6;
21
22const TCP_ULP: libc::c_int = 31;
24
25const SOL_TLS: libc::c_int = 282;
27
28const TLS_TX: libc::c_int = 1;
30
31const TLS_RX: libc::c_int = 2;
33
34pub(crate) fn setup_ulp(fd: RawFd) -> io::Result<()> {
36 unsafe {
37 if libc::setsockopt(
38 fd,
39 SOL_TCP,
40 TCP_ULP,
41 "tls".as_ptr() as *const libc::c_void,
42 3,
43 ) < 0
44 {
45 return Err(io::Error::last_os_error());
46 }
47 }
48
49 Ok(())
50}
51
52pub(crate) fn setup_tls_info(
54 fd: RawFd,
55 dir: Direction,
56 info: CryptoInfo,
57) -> Result<(), crate::Error> {
58 unsafe {
59 if libc::setsockopt(fd, SOL_TLS, dir.as_c_int(), info.as_ptr(), info.size() as _) < 0 {
60 return Err(crate::Error::TlsCryptoInfoError(io::Error::last_os_error()));
61 }
62 }
63 Ok(())
64}
65
/// Direction of traffic that a set of kernel TLS crypto parameters applies to.
///
/// The common comparison and hashing traits are derived so callers can compare
/// directions or use them as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Direction {
    /// Transmit side (data written to the socket).
    Tx,

    /// Receive side (data read from the socket).
    Rx,
}
75
76impl Direction {
77 #[inline]
78 const fn as_c_int(self) -> libc::c_int {
79 match self {
80 Self::Tx => TLS_TX,
81 Self::Rx => TLS_RX,
82 }
83 }
84}
85
86#[allow(dead_code)]
87pub enum CryptoInfo {
91 AesGcm128(bindings::tls12_crypto_info_aes_gcm_128),
92 AesGcm256(bindings::tls12_crypto_info_aes_gcm_256),
93 AesCcm128(bindings::tls12_crypto_info_aes_ccm_128),
94 Chacha20Poly1305(bindings::tls12_crypto_info_chacha20_poly1305),
95 Sm4Gcm(bindings::tls12_crypto_info_sm4_gcm),
96 Sm4Ccm(bindings::tls12_crypto_info_sm4_ccm),
97}
98
99impl CryptoInfo {
100 pub fn as_ptr(&self) -> *const libc::c_void {
102 match self {
103 Self::AesGcm128(info) => info as *const _ as *const libc::c_void,
104 Self::AesGcm256(info) => info as *const _ as *const libc::c_void,
105 Self::AesCcm128(info) => info as *const _ as *const libc::c_void,
106 Self::Chacha20Poly1305(info) => info as *const _ as *const libc::c_void,
107 Self::Sm4Gcm(info) => info as *const _ as *const libc::c_void,
108 Self::Sm4Ccm(info) => info as *const _ as *const libc::c_void,
109 }
110 }
111
112 #[inline]
113 pub fn size(&self) -> usize {
115 match self {
116 Self::AesGcm128(info) => size_of_val(info),
117 Self::AesGcm256(info) => size_of_val(info),
118 Self::AesCcm128(info) => size_of_val(info),
119 Self::Chacha20Poly1305(info) => size_of_val(info),
120 Self::Sm4Gcm(info) => size_of_val(info),
121 Self::Sm4Ccm(info) => size_of_val(info),
122 }
123 }
124}
125
126impl CryptoInfo {
127 pub fn from_rustls(
129 cipher_suite: SupportedCipherSuite,
130 (seq, secrets): (u64, ConnectionTrafficSecrets),
131 ) -> Result<CryptoInfo, KtlsCompatibilityError> {
132 let version = match cipher_suite {
133 SupportedCipherSuite::Tls12(..) => TLS_1_2_VERSION_NUMBER,
134 SupportedCipherSuite::Tls13(..) => TLS_1_3_VERSION_NUMBER,
135 };
136
137 Ok(match secrets {
138 ConnectionTrafficSecrets::Aes128Gcm { key, iv } => {
139 match key.as_ref().len() {
145 16 => CryptoInfo::AesGcm128(bindings::tls12_crypto_info_aes_gcm_128 {
146 info: bindings::tls_crypto_info {
147 version,
148 cipher_type: bindings::TLS_CIPHER_AES_GCM_128 as _,
149 },
150 iv: iv
151 .as_ref()
152 .get(4..)
153 .expect("AES-GCM-128 iv is 8 bytes")
154 .try_into()
155 .expect("AES-GCM-128 iv is 8 bytes"),
156 key: key
157 .as_ref()
158 .try_into()
159 .expect("AES-GCM-128 key is 16 bytes"),
160 salt: iv
161 .as_ref()
162 .get(..4)
163 .expect("AES-GCM-128 salt is 4 bytes")
164 .try_into()
165 .expect("AES-GCM-128 salt is 4 bytes"),
166 rec_seq: seq.to_be_bytes(),
167 }),
168 32 => CryptoInfo::AesGcm256(bindings::tls12_crypto_info_aes_gcm_256 {
169 info: bindings::tls_crypto_info {
170 version,
171 cipher_type: bindings::TLS_CIPHER_AES_GCM_256 as _,
172 },
173 iv: iv
174 .as_ref()
175 .get(4..)
176 .expect("AES-GCM-256 iv is 8 bytes")
177 .try_into()
178 .expect("AES-GCM-256 iv is 8 bytes"),
179 key: key
180 .as_ref()
181 .try_into()
182 .expect("AES-GCM-256 key is 32 bytes"),
183 salt: iv
184 .as_ref()
185 .get(..4)
186 .expect("AES-GCM-256 salt is 4 bytes")
187 .try_into()
188 .expect("AES-GCM-256 salt is 4 bytes"),
189 rec_seq: seq.to_be_bytes(),
190 }),
191 _ => unreachable!("GCM key length is not 16 or 32"),
192 }
193 }
194 ConnectionTrafficSecrets::Aes256Gcm { key, iv } => {
195 CryptoInfo::AesGcm256(bindings::tls12_crypto_info_aes_gcm_256 {
196 info: bindings::tls_crypto_info {
197 version,
198 cipher_type: bindings::TLS_CIPHER_AES_GCM_256 as _,
199 },
200 iv: iv
201 .as_ref()
202 .get(4..)
203 .expect("AES-GCM-256 iv is 8 bytes")
204 .try_into()
205 .expect("AES-GCM-256 iv is 8 bytes"),
206 key: key
207 .as_ref()
208 .try_into()
209 .expect("AES-GCM-256 key is 32 bytes"),
210 salt: iv
211 .as_ref()
212 .get(..4)
213 .expect("AES-GCM-256 salt is 4 bytes")
214 .try_into()
215 .expect("AES-GCM-256 salt is 4 bytes"),
216 rec_seq: seq.to_be_bytes(),
217 })
218 }
219 ConnectionTrafficSecrets::Chacha20Poly1305 { key, iv } => {
220 CryptoInfo::Chacha20Poly1305(bindings::tls12_crypto_info_chacha20_poly1305 {
221 info: bindings::tls_crypto_info {
222 version,
223 cipher_type: bindings::TLS_CIPHER_CHACHA20_POLY1305 as _,
224 },
225 iv: iv
226 .as_ref()
227 .try_into()
228 .expect("Chacha20-Poly1305 iv is 12 bytes"),
229 key: key
230 .as_ref()
231 .try_into()
232 .expect("Chacha20-Poly1305 key is 32 bytes"),
233 salt: bindings::__IncompleteArrayField::new(),
234 rec_seq: seq.to_be_bytes(),
235 })
236 }
237 _ => {
238 return Err(KtlsCompatibilityError::UnsupportedCipherSuite(cipher_suite));
239 }
240 })
241 }
242}
243
244const TLS_SET_RECORD_TYPE: libc::c_int = 1;
245const ALERT: u8 = 0x15;
246
247#[cfg_attr(target_pointer_width = "32", repr(C, align(4)))]
249#[cfg_attr(target_pointer_width = "64", repr(C, align(8)))]
250struct Cmsg<const N: usize> {
251 hdr: libc::cmsghdr,
252 data: [u8; N],
253}
254
255impl<const N: usize> Cmsg<N> {
256 fn new(level: i32, typ: i32, data: [u8; N]) -> Self {
257 Self {
258 hdr: libc::cmsghdr {
259 #[allow(clippy::unnecessary_cast)]
261 cmsg_len: (memoffset::offset_of!(Self, data) + N) as _,
262 cmsg_level: level,
263 cmsg_type: typ,
264 },
265 data,
266 }
267 }
268}
269
270pub(crate) fn send_close_notify(fd: RawFd) -> std::io::Result<()> {
271 let mut data = vec![];
272 Message::build_alert(AlertLevel::Warning, AlertDescription::CloseNotify)
273 .payload
274 .encode(&mut data);
275
276 let mut cmsg = Cmsg::new(SOL_TLS, TLS_SET_RECORD_TYPE, [ALERT]);
277
278 let msg = libc::msghdr {
279 msg_name: std::ptr::null_mut(),
280 msg_namelen: 0,
281 msg_iov: &mut libc::iovec {
282 iov_base: data.as_mut_ptr() as _,
283 iov_len: data.len(),
284 },
285 msg_iovlen: 1,
286 msg_control: &mut cmsg as *mut _ as *mut _,
287 msg_controllen: cmsg.hdr.cmsg_len,
288 msg_flags: 0,
289 };
290
291 let ret = unsafe { libc::sendmsg(fd, &msg, 0) };
292 if ret < 0 {
293 return Err(io::Error::last_os_error());
294 }
295 Ok(())
296}