1pub use tokio_rustls::rustls;
8
9use datum::{Flow, Keep, NotUsed, Sink, Source, StreamCompletion, StreamError, StreamResult};
10use std::net::SocketAddr;
11use std::sync::{Arc, Mutex, mpsc as std_mpsc};
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
13use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
14use tokio::runtime::Handle;
15use tokio::sync::{mpsc, watch};
16use tokio::task::JoinHandle;
17use tokio_rustls::rustls::pki_types::ServerName;
18use tokio_rustls::{TlsAcceptor, TlsConnector};
19
/// Chunk size, in bytes, used by the `*_default` constructors of [`TokioTls`].
const DEFAULT_CHUNK_SIZE: usize = 8 * 1024;
21
22pub type TlsByteSource = Source<Vec<u8>, NotUsed>;
29
30pub type TlsByteSink = Sink<Vec<u8>, StreamCompletion<NotUsed>>;
35
36enum DemandResponse<T> {
37 Item(T),
38 Complete,
39 Error(StreamError),
40}
41
42struct ReadResource {
43 receiver: mpsc::Receiver<DemandResponse<Vec<u8>>>,
44 cancel: watch::Sender<bool>,
45 task: JoinHandle<()>,
46}
47
48impl Drop for ReadResource {
49 fn drop(&mut self) {
50 let _ = self.cancel.send(true);
51 self.task.abort();
52 }
53}
54
55struct BindResource {
56 demands: mpsc::Sender<std_mpsc::Sender<DemandResponse<TlsIncomingConnection>>>,
57 cancel: watch::Sender<bool>,
58 task: JoinHandle<()>,
59}
60
61impl Drop for BindResource {
62 fn drop(&mut self) {
63 let _ = self.cancel.send(true);
64 self.task.abort();
65 }
66}
67
68fn io_error(error: std::io::Error) -> StreamError {
69 StreamError::Failed(error.to_string())
70}
71
72fn abrupt_termination() -> StreamError {
73 StreamError::AbruptTermination
74}
75
/// Socket addresses of the TCP connection underneath a TLS session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TlsConnection {
    /// Address of the local end of the socket.
    pub local_addr: SocketAddr,
    /// Address of the remote peer.
    pub remote_addr: SocketAddr,
}

impl TlsConnection {
    /// Returns the address of the local end of the socket.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr, .. } = *self;
        local_addr
    }

    /// Returns the address of the remote peer.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        let Self { remote_addr, .. } = *self;
        remote_addr
    }
}
94
/// Address a TLS listener is bound to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TlsBinding {
    /// Local address reported by the listening socket.
    pub local_addr: SocketAddr,
}

impl TlsBinding {
    /// Returns the local address of the listening socket.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr } = *self;
        local_addr
    }
}
107
108pub struct TlsIncomingConnection {
110 connection: TlsConnection,
111 source: TlsByteSource,
112 sink: TlsByteSink,
113}
114
115impl TlsIncomingConnection {
116 #[must_use]
117 pub fn local_addr(&self) -> SocketAddr {
118 self.connection.local_addr
119 }
120
121 #[must_use]
122 pub fn remote_addr(&self) -> SocketAddr {
123 self.connection.remote_addr
124 }
125
126 #[must_use]
127 pub fn connection(&self) -> TlsConnection {
128 self.connection
129 }
130
131 #[must_use]
132 pub fn into_parts(self) -> (TlsByteSource, TlsByteSink) {
133 (self.source, self.sink)
134 }
135
136 #[must_use]
137 pub fn into_flow(self) -> Flow<Vec<u8>, Vec<u8>, NotUsed> {
138 Flow::from_sink_and_source_coupled(self.sink, self.source)
139 .map_materialized_value(|_| NotUsed)
140 }
141}
142
/// Entry point for TLS-over-TCP streams built on `tokio` and `tokio-rustls`.
pub struct TokioTls;

/// Shorter alias for [`TokioTls`].
pub type Tls = TokioTls;
148
149impl TokioTls {
150 #[must_use]
156 pub fn outgoing_connection<A>(
157 addr: A,
158 server_name: ServerName<'static>,
159 client_config: Arc<rustls::ClientConfig>,
160 chunk_size: usize,
161 ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
162 where
163 A: ToSocketAddrs + Clone + Send + Sync + 'static,
164 {
165 assert!(chunk_size > 0, "chunk size must be greater than zero");
166 Flow::future_flow(move || {
167 let addr = addr.clone();
168 let server_name = server_name.clone();
169 let client_config = Arc::clone(&client_config);
170 async move {
171 let handle = Handle::current();
172 let tcp = TcpStream::connect(addr).await.map_err(io_error)?;
173 let connection = TlsConnection {
174 local_addr: tcp.local_addr().map_err(io_error)?,
175 remote_addr: tcp.peer_addr().map_err(io_error)?,
176 };
177 let tls = TlsConnector::from(client_config)
178 .connect(server_name, tcp)
179 .await
180 .map_err(io_error)?;
181 Ok(tls_flow_from_stream(tls, connection, handle, chunk_size))
182 }
183 })
184 }
185
186 #[must_use]
188 pub fn outgoing_connection_default<A>(
189 addr: A,
190 server_name: ServerName<'static>,
191 client_config: Arc<rustls::ClientConfig>,
192 ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
193 where
194 A: ToSocketAddrs + Clone + Send + Sync + 'static,
195 {
196 Self::outgoing_connection(addr, server_name, client_config, DEFAULT_CHUNK_SIZE)
197 }
198
199 #[must_use]
205 pub fn bind<A>(
206 addr: A,
207 server_config: Arc<rustls::ServerConfig>,
208 chunk_size: usize,
209 ) -> Source<TlsIncomingConnection, StreamCompletion<TlsBinding>>
210 where
211 A: ToSocketAddrs + Clone + Send + Sync + 'static,
212 {
213 assert!(chunk_size > 0, "chunk size must be greater than zero");
214 Source::lazy_future_source(move || {
215 let addr = addr.clone();
216 let server_config = Arc::clone(&server_config);
217 async move {
218 let handle = Handle::current();
219 let listener = TcpListener::bind(addr).await.map_err(io_error)?;
220 let local_addr = listener.local_addr().map_err(io_error)?;
221 Ok(tls_bind_source(
222 listener,
223 server_config,
224 local_addr,
225 handle,
226 chunk_size,
227 ))
228 }
229 })
230 }
231
232 #[must_use]
234 pub fn bind_default<A>(
235 addr: A,
236 server_config: Arc<rustls::ServerConfig>,
237 ) -> Source<TlsIncomingConnection, StreamCompletion<TlsBinding>>
238 where
239 A: ToSocketAddrs + Clone + Send + Sync + 'static,
240 {
241 Self::bind(addr, server_config, DEFAULT_CHUNK_SIZE)
242 }
243}
244
245pub(crate) fn tls_flow_from_stream<S>(
246 stream: S,
247 connection: TlsConnection,
248 handle: Handle,
249 chunk_size: usize,
250) -> Flow<Vec<u8>, Vec<u8>, TlsConnection>
251where
252 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
253{
254 let (reader, writer) = tokio::io::split(stream);
255 let source = single_use_async_read_source(reader, handle.clone(), chunk_size);
256 let sink = single_use_async_write_sink(writer, handle);
257 Flow::from_sink_and_source(sink, source).map_materialized_value(move |_| connection)
258}
259
260fn tls_incoming_connection<S>(
261 stream: S,
262 connection: TlsConnection,
263 handle: Handle,
264 chunk_size: usize,
265) -> TlsIncomingConnection
266where
267 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
268{
269 let (reader, writer) = tokio::io::split(stream);
270 TlsIncomingConnection {
271 connection,
272 source: single_use_async_read_source(reader, handle.clone(), chunk_size),
273 sink: single_use_async_write_sink(writer, handle),
274 }
275}
276
277fn single_use_async_read_source<R>(reader: R, handle: Handle, chunk_size: usize) -> TlsByteSource
278where
279 R: AsyncRead + Unpin + Send + 'static,
280{
281 let reader = Arc::new(Mutex::new(Some(reader)));
282 Source::unfold_resource(
283 {
284 let reader = Arc::clone(&reader);
285 move || {
286 let reader = reader
287 .lock()
288 .expect("single-use TLS reader poisoned")
289 .take()
290 .ok_or_else(|| StreamError::Failed("TLS reader already materialized".into()))?;
291 let (sender, receiver) = mpsc::channel(1);
292 let (cancel_sender, cancel_receiver) = watch::channel(false);
293 let task = handle.spawn(run_read_task(reader, chunk_size, sender, cancel_receiver));
294 Ok(ReadResource {
295 receiver,
296 cancel: cancel_sender,
297 task,
298 })
299 }
300 },
301 |resource| match resource.receiver.blocking_recv() {
302 Some(DemandResponse::Item(chunk)) => Ok(Some(chunk)),
303 Some(DemandResponse::Complete) => Ok(None),
304 Some(DemandResponse::Error(error)) => Err(error),
305 None => Err(abrupt_termination()),
306 },
307 close_read_resource,
308 )
309}
310
311fn close_read_resource(resource: ReadResource) -> StreamResult<()> {
312 let _ = resource.cancel.send(true);
313 resource.task.abort();
314 Ok(())
315}
316
317async fn run_read_task<R>(
318 mut reader: R,
319 chunk_size: usize,
320 sender: mpsc::Sender<DemandResponse<Vec<u8>>>,
321 mut cancel: watch::Receiver<bool>,
322) where
323 R: AsyncRead + Unpin + Send + 'static,
324{
325 let mut buffer = vec![0_u8; chunk_size];
326 let mut pending_tail = Vec::with_capacity(chunk_size);
327
328 loop {
329 let read = tokio::select! {
330 read = reader.read(&mut buffer) => read,
331 changed = cancel.changed() => {
332 let _ = changed;
333 return;
334 }
335 };
336
337 match read {
338 Ok(0) => {
339 if !pending_tail.is_empty()
340 && !send_read_item(
341 &sender,
342 DemandResponse::Item(std::mem::take(&mut pending_tail)),
343 &mut cancel,
344 )
345 .await
346 {
347 return;
348 }
349 let _ = send_read_item(&sender, DemandResponse::Complete, &mut cancel).await;
350 return;
351 }
352 Ok(read) => {
353 if !send_read_chunks(
354 &sender,
355 chunk_size,
356 &mut pending_tail,
357 &buffer[..read],
358 &mut cancel,
359 )
360 .await
361 {
362 return;
363 }
364 }
365 Err(error) => {
366 let _ =
367 send_read_item(&sender, DemandResponse::Error(io_error(error)), &mut cancel)
368 .await;
369 return;
370 }
371 }
372 }
373}
374
375async fn send_read_chunks(
376 sender: &mpsc::Sender<DemandResponse<Vec<u8>>>,
377 chunk_size: usize,
378 pending_tail: &mut Vec<u8>,
379 read_buffer: &[u8],
380 cancel: &mut watch::Receiver<bool>,
381) -> bool {
382 let mut offset = 0;
383 if !pending_tail.is_empty() {
384 let needed = chunk_size - pending_tail.len();
385 let take = needed.min(read_buffer.len());
386 pending_tail.extend_from_slice(&read_buffer[..take]);
387 offset += take;
388 if pending_tail.len() == chunk_size
389 && !send_read_item(
390 sender,
391 DemandResponse::Item(std::mem::take(pending_tail)),
392 cancel,
393 )
394 .await
395 {
396 return false;
397 }
398 }
399
400 while offset + chunk_size <= read_buffer.len() {
401 let next = offset + chunk_size;
402 if !send_read_item(
403 sender,
404 DemandResponse::Item(read_buffer[offset..next].to_vec()),
405 cancel,
406 )
407 .await
408 {
409 return false;
410 }
411 offset = next;
412 }
413
414 if offset < read_buffer.len() {
415 pending_tail.extend_from_slice(&read_buffer[offset..]);
416 }
417 true
418}
419
420async fn send_read_item<T>(
421 sender: &mpsc::Sender<DemandResponse<T>>,
422 item: DemandResponse<T>,
423 cancel: &mut watch::Receiver<bool>,
424) -> bool
425where
426 T: Send + 'static,
427{
428 tokio::select! {
429 result = sender.send(item) => result.is_ok(),
430 changed = cancel.changed() => {
431 let _ = changed;
432 false
433 }
434 }
435}
436
437fn single_use_async_write_sink<W>(writer: W, handle: Handle) -> TlsByteSink
438where
439 W: AsyncWrite + Unpin + Send + 'static,
440{
441 let writer = Arc::new(Mutex::new(Some(writer)));
442 Flow::<Vec<u8>, Vec<u8>>::identity()
443 .map_with_resource(
444 {
445 let writer = Arc::clone(&writer);
446 move || {
447 writer
448 .lock()
449 .expect("single-use TLS writer poisoned")
450 .take()
451 .ok_or_else(|| {
452 StreamError::Failed("TLS writer already materialized".into())
453 })
454 }
455 },
456 {
457 let handle = handle.clone();
458 move |writer, chunk| {
459 handle.block_on(async {
460 writer.write_all(&chunk).await.map_err(io_error)?;
461 writer.flush().await.map_err(io_error)
462 })?;
463 Ok(())
464 }
465 },
466 move |mut writer| {
467 handle.block_on(async {
468 writer.flush().await.map_err(io_error)?;
469 writer.shutdown().await.map_err(io_error)
470 })?;
471 Ok(None)
472 },
473 )
474 .to_mat(Sink::ignore(), Keep::right)
475}
476
477fn tls_bind_source(
478 listener: TcpListener,
479 server_config: Arc<rustls::ServerConfig>,
480 local_addr: SocketAddr,
481 handle: Handle,
482 chunk_size: usize,
483) -> Source<TlsIncomingConnection, TlsBinding> {
484 let listener = Arc::new(Mutex::new(Some(listener)));
485 Source::unfold_resource(
486 {
487 let listener = Arc::clone(&listener);
488 let handle = handle.clone();
489 move || {
490 let listener = listener
491 .lock()
492 .expect("single-use TLS listener poisoned")
493 .take()
494 .ok_or_else(|| {
495 StreamError::Failed("TLS listener already materialized".into())
496 })?;
497 let (demand_sender, demand_receiver) = mpsc::channel(1);
498 let (cancel_sender, cancel_receiver) = watch::channel(false);
499 let task = handle.spawn(run_tls_bind_task(
500 listener,
501 Arc::clone(&server_config),
502 local_addr,
503 chunk_size,
504 handle.clone(),
505 demand_receiver,
506 cancel_receiver,
507 ));
508 Ok(BindResource {
509 demands: demand_sender,
510 cancel: cancel_sender,
511 task,
512 })
513 }
514 },
515 |resource| {
516 let (reply_sender, reply_receiver) = std_mpsc::channel();
517 resource
518 .demands
519 .blocking_send(reply_sender)
520 .map_err(|_| abrupt_termination())?;
521 match reply_receiver.recv() {
522 Ok(DemandResponse::Item(connection)) => Ok(Some(connection)),
523 Ok(DemandResponse::Complete) => Ok(None),
524 Ok(DemandResponse::Error(error)) => Err(error),
525 Err(_) => Err(abrupt_termination()),
526 }
527 },
528 close_bind_resource,
529 )
530 .map_materialized_value(move |_| TlsBinding { local_addr })
531}
532
533fn close_bind_resource(resource: BindResource) -> StreamResult<()> {
534 let _ = resource.cancel.send(true);
535 resource.task.abort();
536 Ok(())
537}
538
539async fn run_tls_bind_task(
540 listener: TcpListener,
541 server_config: Arc<rustls::ServerConfig>,
542 local_addr: SocketAddr,
543 chunk_size: usize,
544 handle: Handle,
545 mut demands: mpsc::Receiver<std_mpsc::Sender<DemandResponse<TlsIncomingConnection>>>,
546 mut cancel: watch::Receiver<bool>,
547) {
548 let acceptor = TlsAcceptor::from(server_config);
549 loop {
550 let reply = tokio::select! {
551 demand = demands.recv() => match demand {
552 Some(reply) => reply,
553 None => return,
554 },
555 changed = cancel.changed() => {
556 let _ = changed;
557 return;
558 }
559 };
560
561 let (tcp, remote_addr) = loop {
562 let accepted = tokio::select! {
563 accepted = listener.accept() => accepted,
564 changed = cancel.changed() => {
565 let _ = changed;
566 return;
567 }
568 };
569
570 match accepted {
571 Ok(accepted) => break accepted,
572 Err(error) if is_transient_accept_error(&error) => continue,
573 Err(error) => {
574 let _ = reply.send(DemandResponse::Error(io_error(error)));
575 return;
576 }
577 }
578 };
579
580 let connection = TlsConnection {
581 local_addr: tcp.local_addr().unwrap_or(local_addr),
582 remote_addr,
583 };
584 let accepted = tokio::select! {
585 accepted = acceptor.accept(tcp) => accepted,
586 changed = cancel.changed() => {
587 let _ = changed;
588 return;
589 }
590 };
591
592 match accepted {
593 Ok(stream) => {
594 let incoming =
595 tls_incoming_connection(stream, connection, handle.clone(), chunk_size);
596 if reply.send(DemandResponse::Item(incoming)).is_err() {
597 return;
598 }
599 }
600 Err(error) => {
601 let _ = reply.send(DemandResponse::Error(io_error(error)));
602 return;
603 }
604 }
605 }
606}
607
/// Returns `true` when an `accept` failure concerns only the connection being accepted, so
/// the listener should retry instead of failing the binding.
///
/// The portable checks go through [`std::io::ErrorKind`]. On Linux, `accept(2)` also passes
/// network errors that were already pending on the new socket back to the caller, and its
/// man page recommends treating them like `EAGAIN` and retrying. Those are matched by raw
/// errno in [`is_transient_accept_errno`].
///
/// Resource-exhaustion errors such as `EMFILE` / `ENFILE` are deliberately not treated as
/// transient here: retrying them immediately would spin in a tight loop.
fn is_transient_accept_error(error: &std::io::Error) -> bool {
    matches!(
        error.kind(),
        std::io::ErrorKind::Interrupted
            | std::io::ErrorKind::ConnectionAborted
            | std::io::ErrorKind::ConnectionReset
    ) || error.raw_os_error().is_some_and(is_transient_accept_errno)
}

/// Linux errno values that `accept(2)` may report for a single failed connection.
///
/// The numbers follow `asm-generic/errno.h`, which x86, ARM, RISC-V and most other Linux
/// architectures share. MIPS and SPARC Linux number errno differently, so they use the
/// fallback below and rely on the portable `ErrorKind` checks only.
#[cfg(all(
    target_os = "linux",
    not(any(
        target_arch = "mips",
        target_arch = "mips64",
        target_arch = "mips32r6",
        target_arch = "mips64r6",
        target_arch = "sparc",
        target_arch = "sparc64"
    ))
))]
fn is_transient_accept_errno(code: i32) -> bool {
    const EINTR: i32 = 4;
    const ENONET: i32 = 64;
    const EPROTO: i32 = 71;
    const ENOPROTOOPT: i32 = 92;
    const EOPNOTSUPP: i32 = 95;
    const ENETDOWN: i32 = 100;
    const ENETUNREACH: i32 = 101;
    const ECONNABORTED: i32 = 103;
    const ECONNRESET: i32 = 104;
    const EHOSTDOWN: i32 = 112;
    const EHOSTUNREACH: i32 = 113;
    matches!(
        code,
        EINTR
            | ENONET
            | EPROTO
            | ENOPROTOOPT
            | EOPNOTSUPP
            | ENETDOWN
            | ENETUNREACH
            | ECONNABORTED
            | ECONNRESET
            | EHOSTDOWN
            | EHOSTUNREACH
    )
}

/// Fallback for platforms without a known errno numbering: no extra raw-errno matches.
#[cfg(not(all(
    target_os = "linux",
    not(any(
        target_arch = "mips",
        target_arch = "mips64",
        target_arch = "mips32r6",
        target_arch = "mips64r6",
        target_arch = "sparc",
        target_arch = "sparc64"
    ))
)))]
fn is_transient_accept_errno(_code: i32) -> bool {
    false
}