1pub mod cork;
2
3use std::{
4 io::{self, IoSliceMut},
5 os::unix::io::AsRawFd,
6 pin::Pin,
7 task,
8};
9
10use nix::{
11 errno::Errno,
12 sys::socket::{recvmsg, ControlMessageOwned, MsgFlags, SockaddrIn, TlsGetRecordType},
13};
14use num_enum::FromPrimitive;
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16
17use crate::AsyncReadReady;
18
pin_project_lite::pin_project! {
    /// Async stream over `inner`, a socket presumably already configured for
    /// kernel TLS (kTLS).
    ///
    /// Application data is read from and written to `inner` directly. Non-data
    /// TLS records reported by the socket (alerts, handshake messages) are
    /// handled in `poll_read`.
    pub struct KtlsStream<IO>
    where
        IO: AsRawFd
    {
        // The wrapped socket; structurally pinned so its poll methods can be called.
        #[pin]
        inner: IO,
        // Set once a close_notify has been sent (or attempted); `poll_write`
        // then reports 0 bytes written.
        write_closed: bool,
        // Set once a close_notify or fatal alert was received; `poll_read`
        // then reports EOF.
        read_closed: bool,
        // Plaintext handed over at construction, delivered before any socket
        // read: (bytes already delivered, buffer).
        drained: Option<(usize, Vec<u8>)>,
    }
}
32
33impl<IO> KtlsStream<IO>
34where
35 IO: AsRawFd,
36{
37 pub fn new(inner: IO, drained: Option<Vec<u8>>) -> Self {
38 Self {
39 inner,
40 write_closed: false,
41 read_closed: false,
42 drained: drained.map(|drained| (0, drained)),
43 }
44 }
45
46 pub fn into_raw(self) -> (Option<Vec<u8>>, IO) {
48 (self.drained.map(|(_, drained)| drained), self.inner)
49 }
50
51 pub fn get_ref(&self) -> &IO {
53 &self.inner
54 }
55
56 pub fn get_mut(&mut self) -> &mut IO {
58 &mut self.inner
59 }
60}
61
62#[derive(Debug, PartialEq, Clone, Copy, num_enum::FromPrimitive)]
63#[repr(u8)]
64enum TlsAlertLevel {
65 Warning = 1,
66 Fatal = 2,
67 #[num_enum(catch_all)]
68 Other(u8),
69}
70
71#[derive(Debug, PartialEq, Clone, Copy, num_enum::FromPrimitive)]
72#[repr(u8)]
73enum TlsAlertDescription {
74 CloseNotify = 0,
75 #[num_enum(catch_all)]
76 Other(u8),
77}
78
79impl<'a, IO> AsyncRead for KtlsStream<IO>
80where
81 IO: AsRawFd + AsyncRead + AsyncReadReady<'a>,
82{
83 fn poll_read(
84 self: Pin<&mut Self>,
85 cx: &mut task::Context<'_>,
86 buf: &mut ReadBuf<'_>,
87 ) -> task::Poll<io::Result<()>> {
88 tracing::trace!(buf.remaining = %buf.remaining(), "KtlsStream::poll_read");
89
90 if self.read_closed {
91 return task::Poll::Ready(Ok(()));
92 }
93
94 if buf.remaining() == 0 {
95 return task::Poll::Ready(Ok(()));
96 }
97
98 let mut this = self.project();
99
100 if let Some((drain_index, drained)) = this.drained.as_mut() {
101 let drained = &drained[*drain_index..];
102 let len = std::cmp::min(buf.remaining(), drained.len());
103
104 tracing::trace!(%len, "KtlsStream::poll_read, can take from drain");
105 buf.put_slice(&drained[..len]);
106
107 *drain_index += len;
108 if *drain_index >= drained.len() {
109 tracing::trace!("KtlsStream::poll_read, done draining");
110 *this.drained = None;
111 }
112 cx.waker().wake_by_ref();
113
114 tracing::trace!("KtlsStream::poll_read, returning after drain");
115 return task::Poll::Ready(Ok(()));
116 }
117
118 let read_res = this.inner.as_mut().poll_read(cx, buf);
119 if let task::Poll::Ready(Err(e)) = &read_res {
120 if let Some(5) = e.raw_os_error() {
124 let fd = this.inner.as_raw_fd();
126
127 let mut cmsgspace = Vec::with_capacity(unsafe {
134 libc::CMSG_SPACE(std::mem::size_of::<u8>() as _) as _
135 });
136
137 let mut iov = [IoSliceMut::new(buf.initialize_unfilled())];
138 let flags = MsgFlags::empty();
139
140 let r = recvmsg::<SockaddrIn>(fd, &mut iov, Some(&mut cmsgspace), flags);
141 let r = match r {
142 Ok(r) => r,
143 Err(Errno::EAGAIN) => {
144 unreachable!("expected a control message, got EAGAIN")
145 }
146 Err(e) => {
147 tracing::trace!(?e, "recvmsg failed");
149 return Err(e.into()).into();
150 }
151 };
152 let cmsg = r
153 .cmsgs()?
154 .next()
155 .expect("we should've received exactly one control message");
156
157 let record_type = match cmsg {
158 ControlMessageOwned::TlsGetRecordType(t) => t,
159 _ => panic!("unexpected cmsg type: {cmsg:#?}"),
160 };
161
162 match record_type {
163 TlsGetRecordType::ChangeCipherSpec => {
164 panic!("change_cipher_spec isn't supported by the ktls crate")
165 }
166 TlsGetRecordType::Alert => {
167 let iov = r.iovs().next().expect("expected data in iovs");
169
170 let (level, description) = match iov {
171 [] => {
172 unreachable!();
174 }
175 &[level] => {
176 (
187 TlsAlertLevel::from_primitive(level),
188 TlsAlertDescription::CloseNotify,
189 )
190 }
191 &[level, description] => (
192 TlsAlertLevel::from_primitive(level),
193 TlsAlertDescription::from_primitive(description),
194 ),
195 _ => {
196 unreachable!(
197 "TLS alerts are exactly 2 bytes, your kTLS is misbehaving"
198 );
199 }
200 };
201
202 match (level, description) {
203 (_, TlsAlertDescription::CloseNotify) | (TlsAlertLevel::Fatal, _) => {
207 tracing::trace!(?level, ?description, "got TLS alert");
208 *this.read_closed = true;
209 *this.write_closed = true;
210 if let Err(e) =
211 crate::ffi::send_close_notify(this.inner.as_raw_fd())
212 {
213 return Err(e).into();
214 }
215 return task::Poll::Ready(Ok(()));
219 }
220 _ => {
221 }
223 }
224 return task::Poll::Ready(Ok(()));
225 }
226 TlsGetRecordType::Handshake => {
227 tracing::trace!(
231 "ignoring handshake message (probably a resumption ticket)"
232 );
233 }
234 TlsGetRecordType::ApplicationData => {
235 unreachable!(
236 "received TLS application in recvmsg, this is supposed to happen in \
237 the poll_read codepath"
238 )
239 }
240 TlsGetRecordType::Unknown(t) => {
241 tracing::trace!("received record_type {t:#?}");
243 }
244 _ => {
245 tracing::trace!("received unsupported record type");
246 }
247 };
248
249 cx.waker().wake_by_ref();
258 return task::Poll::Pending;
259 }
260 }
261
262 read_res
263 }
264}
265
266impl<IO> AsyncWrite for KtlsStream<IO>
267where
268 IO: AsRawFd + AsyncWrite,
269{
270 fn poll_write(
271 self: Pin<&mut Self>,
272 cx: &mut task::Context<'_>,
273 buf: &[u8],
274 ) -> task::Poll<io::Result<usize>> {
275 if self.write_closed {
276 return task::Poll::Ready(Ok(0));
277 }
278
279 self.project().inner.poll_write(cx, buf)
280 }
281
282 fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>> {
283 self.project().inner.poll_flush(cx)
284 }
285
286 fn poll_shutdown(
287 self: Pin<&mut Self>,
288 cx: &mut task::Context<'_>,
289 ) -> task::Poll<io::Result<()>> {
290 let this = self.project();
291
292 if !*this.write_closed {
293 *this.write_closed = true;
296 if let Err(e) = crate::ffi::send_close_notify(this.inner.as_raw_fd()) {
297 return Err(e).into();
298 }
299 }
300
301 this.inner.poll_shutdown(cx)
303 }
304}
305
306impl<IO> AsRawFd for KtlsStream<IO>
307where
308 IO: AsRawFd,
309{
310 fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
311 self.inner.as_raw_fd()
312 }
313}