1pub use tokio_rustls::rustls;
8
9use crate::async_carrier::{self, AsyncCommandSender, DemandBatcher};
10use datum::{Flow, Keep, NotUsed, Sink, Source, StreamCompletion, StreamError, StreamResult};
11use std::net::SocketAddr;
12use std::sync::{Arc, Mutex, atomic::AtomicUsize, mpsc as std_mpsc};
13use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
14use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
15use tokio::runtime::Handle;
16use tokio::sync::{mpsc, watch};
17use tokio::task::JoinHandle;
18use tokio_rustls::rustls::pki_types::ServerName;
19use tokio_rustls::{TlsAcceptor, TlsConnector};
20
/// Default size, in bytes, of the chunks emitted by TLS byte sources
/// (used by `TokioTls::outgoing_connection_default` and `TokioTls::bind_default`).
const DEFAULT_CHUNK_SIZE: usize = 8192;
/// Receive-side buffering per connection, counted in chunks: sizes the demand batches requested
/// from the carrier task and the bounded channel the carrier pushes read chunks into.
const DEFAULT_RECEIVE_BUFFER: usize = 64;

// Handed to `async_carrier::sharded_tokio_carrier_execution` for every new TLS connection;
// presumably a live-connection counter used to pick an execution shard (confirm in
// `async_carrier`).
static ACTIVE_TLS_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

/// Source of plaintext byte chunks read from a TLS connection.
pub type TlsByteSource = Source<Vec<u8>, NotUsed>;

/// Sink of plaintext byte chunks written to a TLS connection. Its materialized value is the
/// completion of the trailing `Sink::ignore` stage.
pub type TlsByteSink = Sink<Vec<u8>, StreamCompletion<NotUsed>>;
39
/// One reply on a demand-driven channel: an element, end of stream, or a failure.
///
/// Used both for TLS read chunks (`DemandResponse<Vec<u8>>`) and for accepted connections
/// (`DemandResponse<TlsIncomingConnection>`).
enum DemandResponse<T> {
    Item(T),
    Complete,
    Error(StreamError),
}
45
/// State held by one materialized TLS read source.
struct ReadResource {
    // Responses (chunks, completion, errors) pushed by the carrier task.
    receiver: std_mpsc::Receiver<DemandResponse<Vec<u8>>>,
    // Handle used to request more demand and to close the read side.
    carrier: TlsCarrier,
    // Counts consumed chunks and decides when more demand should be requested.
    demand: DemandBatcher,
    // A response already taken from `receiver` during setup; returned before reading again.
    pending: Option<DemandResponse<Vec<u8>>>,
}

impl Drop for ReadResource {
    // Ask the carrier to stop reading however this resource goes away. After
    // `close_read_resource` this is a second `CloseRead`, which the carrier treats the same way.
    fn drop(&mut self) {
        self.carrier.close_read();
    }
}
58
/// Commands accepted by the per-connection carrier task.
enum TlsCarrierCommand {
    // Allow the carrier to deliver this many more read chunks.
    Demand(usize),
    // Write one chunk and flush.
    SendOne(Vec<u8>),
    // Write all chunks, then flush once.
    SendBatch(Vec<Vec<u8>>),
    // Stop reading from the stream.
    CloseRead,
    // Flush and shut down the write half; the outcome is sent on `ack`.
    CloseWrite {
        ack: std_mpsc::Sender<StreamResult<()>>,
    },
}
68
/// Cloneable handle to a connection's carrier task, shared by its read and write halves.
#[derive(Clone)]
struct TlsCarrier {
    inner: Arc<TlsCarrierInner>,
}

/// Shared state behind `TlsCarrier`; dropping the last handle aborts the carrier task.
struct TlsCarrierInner {
    // Command channel into the carrier task.
    commands: AsyncCommandSender<TlsCarrierCommand>,
    // Failures recorded by the carrier task (write errors and read errors), polled without
    // blocking by `TlsCarrier::check_send_error`.
    send_errors: Mutex<std_mpsc::Receiver<StreamError>>,
    // The spawned carrier task; taken and aborted on drop.
    task: Mutex<Option<JoinHandle<()>>>,
    // Kept alive as long as the connection; presumably pins the runtime/shard the task runs on
    // (confirm in `async_carrier`).
    _execution: async_carrier::ShardedTokioCarrierExecution,
}
80
81impl Drop for TlsCarrierInner {
82 fn drop(&mut self) {
83 if let Some(task) = self.task.lock().expect("TLS carrier task poisoned").take() {
84 task.abort();
85 }
86 }
87}
88
89impl TlsCarrier {
90 fn close_read(&self) {
91 let _ = self.inner.commands.try_send(TlsCarrierCommand::CloseRead);
92 }
93
94 fn request_demand(&self, demand: usize) -> StreamResult<()> {
95 self.inner
96 .commands
97 .send_or_blocking(TlsCarrierCommand::Demand(demand))
98 }
99
100 fn send_items(&self, items: Vec<Vec<u8>>) -> StreamResult<()> {
101 self.check_send_error()?;
102 self.inner
103 .commands
104 .send_or_blocking(TlsCarrierCommand::SendBatch(items))
105 .map_err(|error| StreamError::Failed(format!("TLS send batch failed: {error:?}")))
106 }
107
108 fn send_one(&self, item: Vec<u8>) -> StreamResult<()> {
109 self.check_send_error()?;
110 self.inner
111 .commands
112 .send_or_blocking(TlsCarrierCommand::SendOne(item))
113 .map_err(|error| StreamError::Failed(format!("TLS send failed: {error:?}")))
114 }
115
116 fn close_write(&self) -> StreamResult<()> {
117 self.check_send_error()?;
118 let (ack_sender, ack_receiver) = std_mpsc::channel();
119 if self
120 .inner
121 .commands
122 .send_or_blocking(TlsCarrierCommand::CloseWrite { ack: ack_sender })
123 .is_err()
124 {
125 return Ok(());
126 }
127 match ack_receiver.recv() {
128 Ok(result) => result,
129 Err(_) => Err(abrupt_termination()),
130 }?;
131 self.check_send_error()
132 }
133
134 fn check_send_error(&self) -> StreamResult<()> {
135 match self
136 .inner
137 .send_errors
138 .lock()
139 .expect("TLS carrier send error receiver poisoned")
140 .try_recv()
141 {
142 Ok(error) => Err(error),
143 Err(std_mpsc::TryRecvError::Empty) | Err(std_mpsc::TryRecvError::Disconnected) => {
144 Ok(())
145 }
146 }
147 }
148}
149
/// State held by one materialized TLS sink.
struct SendResource {
    carrier: TlsCarrier,
    // Chunks buffered until `batch_size` is reached (only used when `batch_size > 1`).
    pending: Vec<Vec<u8>>,
    // Number of chunks sent per batch; `1` or less sends every chunk immediately.
    batch_size: usize,
}

/// State held by one materialized TLS bind source.
struct BindResource {
    // Each pull sends a one-shot reply sender to the bind task on this channel.
    demands: mpsc::Sender<std_mpsc::Sender<DemandResponse<TlsIncomingConnection>>>,
    // Set to `true` to ask the bind task to stop.
    cancel: watch::Sender<bool>,
    task: JoinHandle<()>,
}

impl Drop for BindResource {
    // Signal cancellation and abort the bind task, however the resource goes away.
    fn drop(&mut self) {
        let _ = self.cancel.send(true);
        self.task.abort();
    }
}
168
169fn io_error(error: std::io::Error) -> StreamError {
170 StreamError::Failed(error.to_string())
171}
172
173fn abrupt_termination() -> StreamError {
174 StreamError::AbruptTermination
175}
176
/// Socket addresses of an established TLS connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TlsConnection {
    pub local_addr: SocketAddr,
    pub remote_addr: SocketAddr,
}

impl TlsConnection {
    /// Local address of the underlying TCP socket.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }

    /// Address of the peer.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        self.remote_addr
    }
}
195
/// Materialized value of a TLS bind source: the address the listener is bound to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TlsBinding {
    pub local_addr: SocketAddr,
}

impl TlsBinding {
    /// Address the TCP listener is bound to (useful when binding to port 0).
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }
}
208
/// A server-side connection that completed its TLS handshake.
///
/// Its byte source and sink are single-use: each can be materialized once.
pub struct TlsIncomingConnection {
    connection: TlsConnection,
    source: TlsByteSource,
    sink: TlsByteSink,
}

impl TlsIncomingConnection {
    /// Local address of the accepted socket.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        self.connection.local_addr
    }

    /// Address of the connected client.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        self.connection.remote_addr
    }

    /// Both addresses of the connection.
    #[must_use]
    pub fn connection(&self) -> TlsConnection {
        self.connection
    }

    /// Splits the connection into its received-bytes source and bytes-to-send sink.
    #[must_use]
    pub fn into_parts(self) -> (TlsByteSource, TlsByteSink) {
        (self.source, self.sink)
    }

    /// Joins sink (bytes to send) and source (bytes received) into one flow.
    ///
    /// Uses `Flow::from_sink_and_source_coupled`, presumably so that termination of either side
    /// terminates the other.
    /// NOTE(review): outgoing connections use the non-coupled `Flow::from_sink_and_source`
    /// (see `tls_flow_from_stream_with_execution`); confirm the asymmetry is intended.
    #[must_use]
    pub fn into_flow(self) -> Flow<Vec<u8>, Vec<u8>, NotUsed> {
        Flow::from_sink_and_source_coupled(self.sink, self.source)
            .map_materialized_value(|_| NotUsed)
    }
}
243
/// Entry point for TLS client and server streams built on Tokio and rustls.
pub struct TokioTls;

/// Short alias for [`TokioTls`].
pub type Tls = TokioTls;
249
impl TokioTls {
    /// Flow that connects to `addr` over TCP, performs a TLS client handshake for
    /// `server_name` using `client_config`, then writes incoming elements to the peer and emits
    /// the bytes it receives in chunks of `chunk_size` bytes.
    ///
    /// The connection is made by the factory passed to `Flow::future_flow`, which clones its
    /// inputs on each call, so it can presumably run once per materialization. The factory's
    /// future calls `Handle::current()`, so it must be polled inside a Tokio runtime.
    ///
    /// # Panics
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn outgoing_connection<A>(
        addr: A,
        server_name: ServerName<'static>,
        client_config: Arc<rustls::ClientConfig>,
        chunk_size: usize,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        Flow::future_flow(move || {
            let addr = addr.clone();
            let server_name = server_name.clone();
            let client_config = Arc::clone(&client_config);
            async move {
                let handle = Handle::current();
                tls_client_connect(addr, server_name, client_config, handle, chunk_size).await
            }
        })
    }

    /// [`TokioTls::outgoing_connection`] with the default chunk size (8192 bytes).
    #[must_use]
    pub fn outgoing_connection_default<A>(
        addr: A,
        server_name: ServerName<'static>,
        client_config: Arc<rustls::ClientConfig>,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        Self::outgoing_connection(addr, server_name, client_config, DEFAULT_CHUNK_SIZE)
    }

    /// Source that lazily binds a TCP listener on `addr` and emits one
    /// [`TlsIncomingConnection`] per client that completes a TLS handshake with
    /// `server_config`. Each connection's source emits chunks of `chunk_size` bytes.
    ///
    /// The materialized value resolves to the bound [`TlsBinding`]. Binding happens inside
    /// the future given to `Source::lazy_future_source`, which calls `Handle::current()`, so it
    /// must be polled inside a Tokio runtime.
    ///
    /// # Panics
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn bind<A>(
        addr: A,
        server_config: Arc<rustls::ServerConfig>,
        chunk_size: usize,
    ) -> Source<TlsIncomingConnection, StreamCompletion<TlsBinding>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        Source::lazy_future_source(move || {
            let addr = addr.clone();
            let server_config = Arc::clone(&server_config);
            async move {
                let handle = Handle::current();
                let listener = TcpListener::bind(addr).await.map_err(io_error)?;
                let local_addr = listener.local_addr().map_err(io_error)?;
                Ok(tls_bind_source(
                    listener,
                    server_config,
                    local_addr,
                    handle,
                    chunk_size,
                ))
            }
        })
    }

    /// [`TokioTls::bind`] with the default chunk size (8192 bytes).
    #[must_use]
    pub fn bind_default<A>(
        addr: A,
        server_config: Arc<rustls::ServerConfig>,
    ) -> Source<TlsIncomingConnection, StreamCompletion<TlsBinding>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        Self::bind(addr, server_config, DEFAULT_CHUNK_SIZE)
    }
}
336
337pub(crate) fn tls_flow_from_stream_with_execution<S>(
338 stream: S,
339 connection: TlsConnection,
340 execution: async_carrier::ShardedTokioCarrierExecution,
341 chunk_size: usize,
342) -> Flow<Vec<u8>, Vec<u8>, TlsConnection>
343where
344 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
345{
346 let (source, sink) = single_use_tls_halves(stream, execution, chunk_size);
347 Flow::from_sink_and_source(sink, source).map_materialized_value(move |_| connection)
348}
349
350fn tls_incoming_connection<S>(
351 stream: S,
352 connection: TlsConnection,
353 execution: async_carrier::ShardedTokioCarrierExecution,
354 chunk_size: usize,
355) -> TlsIncomingConnection
356where
357 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
358{
359 let (source, sink) = single_use_tls_halves(stream, execution, chunk_size);
360 TlsIncomingConnection {
361 connection,
362 source,
363 sink,
364 }
365}
366
367fn single_use_tls_halves<S>(
368 stream: S,
369 execution: async_carrier::ShardedTokioCarrierExecution,
370 chunk_size: usize,
371) -> (TlsByteSource, TlsByteSink)
372where
373 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
374{
375 let (carrier, receiver) =
376 start_tls_carrier(stream, execution, chunk_size, DEFAULT_RECEIVE_BUFFER);
377 let source =
378 single_use_tls_source_from_carrier(carrier.clone(), receiver, DEFAULT_RECEIVE_BUFFER);
379 let sink = single_use_tls_sink_from_carrier(carrier, 1);
380 (source, sink)
381}
382
/// Builds the read half of a connection: a `Source` over the carrier's receive channel that
/// can be materialized only once.
///
/// On materialization the receiver is taken out of the shared slot (a second materialization
/// fails with "TLS source already materialized"), the first batch of demand is requested, and a
/// [`ReadResource`] is created. Chunks are then pulled by `read_next_chunk`, and
/// `close_read_resource` runs on close.
fn single_use_tls_source_from_carrier(
    carrier: TlsCarrier,
    receiver: std_mpsc::Receiver<DemandResponse<Vec<u8>>>,
    receive_buffer: usize,
) -> TlsByteSource {
    // Shared slot so the create closure can move the receiver out exactly once.
    let receiver = Arc::new(Mutex::new(Some(receiver)));
    Source::unfold_resource(
        {
            let receiver = Arc::clone(&receiver);
            move || {
                let receiver = receiver
                    .lock()
                    .expect("single-use TLS receiver poisoned")
                    .take()
                    .ok_or_else(|| StreamError::Failed("TLS source already materialized".into()))?;
                let demand = DemandBatcher::new(receive_buffer);
                // If the initial demand cannot be delivered, prefer a response the carrier
                // already queued (for example its completion or error) over the send error.
                let pending = match carrier.request_demand(demand.initial()) {
                    Ok(()) => None,
                    Err(error) => match receiver.try_recv() {
                        Ok(response) => Some(response),
                        Err(std_mpsc::TryRecvError::Empty) => return Err(error),
                        Err(std_mpsc::TryRecvError::Disconnected) => {
                            return Err(abrupt_termination());
                        }
                    },
                };
                Ok(ReadResource {
                    receiver,
                    carrier: carrier.clone(),
                    demand,
                    pending,
                })
            }
        },
        read_next_chunk,
        close_read_resource,
    )
}
421
422fn read_next_chunk(resource: &mut ReadResource) -> StreamResult<Option<Vec<u8>>> {
423 let response = match resource.pending.take() {
424 Some(response) => response,
425 None => resource.receiver.recv().map_err(|_| abrupt_termination())?,
426 };
427 match response {
428 DemandResponse::Item(chunk) => {
429 if let Some(demand) = resource.demand.record_consumed() {
430 let _ = resource.carrier.request_demand(demand);
431 }
432 Ok(Some(chunk))
433 }
434 DemandResponse::Complete => Ok(None),
435 DemandResponse::Error(error) => Err(error),
436 }
437}
438
439fn close_read_resource(resource: ReadResource) -> StreamResult<()> {
440 resource.carrier.close_read();
441 Ok(())
442}
443
/// Splits `stream`, spawns its carrier task on `execution`, and returns a handle to the task
/// plus the channel the task pushes read responses into.
///
/// The command channel holds at least `receive_buffer` commands. The receive channel is
/// bounded at `receive_buffer + 1` responses — presumably one full demand batch plus a terminal
/// `Complete`/`Error` (confirm against `DemandBatcher`).
fn start_tls_carrier<S>(
    stream: S,
    execution: async_carrier::ShardedTokioCarrierExecution,
    chunk_size: usize,
    receive_buffer: usize,
) -> (TlsCarrier, std_mpsc::Receiver<DemandResponse<Vec<u8>>>)
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    let command_capacity = async_carrier::DEFAULT_COMMAND_BUFFER.max(receive_buffer);
    let (commands, command_receiver) = async_carrier::command_channel(command_capacity, "TLS");
    let (send_error_sender, send_error_receiver) = std_mpsc::channel();
    let (receive_sender, receive_receiver) =
        std_mpsc::sync_channel(receive_buffer.saturating_add(1));
    let (reader, writer) = tokio::io::split(stream);
    // The task holds its own sender for the command channel (see `run_tls_carrier_task`).
    let command_keepalive = commands.clone();
    let task = execution.handle().spawn(run_tls_carrier_task(
        reader,
        writer,
        chunk_size,
        receive_sender,
        send_error_sender,
        command_keepalive,
        command_receiver,
    ));
    (
        TlsCarrier {
            inner: Arc::new(TlsCarrierInner {
                commands,
                send_errors: Mutex::new(send_error_receiver),
                task: Mutex::new(Some(task)),
                _execution: execution,
            }),
        },
        receive_receiver,
    )
}
481
/// Carrier task that owns one connection's stream and services both of its directions.
///
/// Reading is demand-driven: the stream is polled only while the read side is open and
/// `requested` (outstanding downstream demand, in chunks) is non-zero. Bytes read are
/// re-chunked by `queue_tls_read_chunks` and pushed without blocking into the bounded
/// `receive_sender` channel. Commands (demand, writes, half-closes) are handled in between; the
/// `select!` is `biased`, so a ready command is handled before a ready read.
///
/// The task returns once both halves are closed, when `commands` yields `None`, after a read
/// error, after the receive channel overflows, or when `handle_tls_carrier_command` returns
/// `false`.
///
/// NOTE(review): `_command_keepalive` appears to be a sender for `commands` held by this task,
/// which would keep `commands.recv()` from ever yielding `None`; shutdown would then rely on the
/// abort in `TlsCarrierInner::drop`. Confirm against `async_carrier`.
async fn run_tls_carrier_task<R, W>(
    mut reader: R,
    mut writer: W,
    chunk_size: usize,
    receive_sender: std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
    send_error_sender: std_mpsc::Sender<StreamError>,
    _command_keepalive: AsyncCommandSender<TlsCarrierCommand>,
    mut commands: mpsc::Receiver<TlsCarrierCommand>,
) where
    R: AsyncRead + Unpin + Send + 'static,
    W: AsyncWrite + Unpin + Send + 'static,
{
    // One read buffer for the connection; a single read never exceeds one chunk.
    let mut buffer = vec![0_u8; chunk_size];
    // Bytes read that do not yet fill a whole chunk.
    let mut pending_tail = Vec::with_capacity(chunk_size);
    // Outstanding downstream demand, in chunks.
    let mut requested = 0_usize;
    let mut read_open = true;
    let mut write_open = true;

    loop {
        if !read_open && !write_open {
            return;
        }

        if read_open && requested > 0 {
            tokio::select! {
                // Commands are polled before the reader.
                biased;
                command = commands.recv() => {
                    let Some(command) = command else {
                        return;
                    };
                    if !handle_tls_carrier_command(
                        &mut writer,
                        command,
                        &send_error_sender,
                        &mut read_open,
                        &mut write_open,
                        &mut requested,
                    ).await {
                        return;
                    }
                }
                read = reader.read(&mut buffer) => {
                    match read {
                        // EOF: deliver any partial chunk still buffered, then signal completion.
                        Ok(0) => {
                            if !pending_tail.is_empty() {
                                match try_send_tls_read_response(
                                    &receive_sender,
                                    DemandResponse::Item(std::mem::take(&mut pending_tail)),
                                ) {
                                    TlsQueueOutcome::Queued => {
                                        requested = requested.saturating_sub(1);
                                    }
                                    // The read source is gone; stop reading.
                                    TlsQueueOutcome::Closed => {
                                        read_open = false;
                                        continue;
                                    }
                                    TlsQueueOutcome::Full => {
                                        report_tls_read_error(
                                            &receive_sender,
                                            &send_error_sender,
                                            tls_receive_buffer_overflow(),
                                        );
                                        return;
                                    }
                                }
                            }
                            match try_send_tls_read_response(
                                &receive_sender,
                                DemandResponse::Complete,
                            ) {
                                TlsQueueOutcome::Queued | TlsQueueOutcome::Closed => {
                                    read_open = false;
                                }
                                TlsQueueOutcome::Full => {
                                    report_tls_read_error(
                                        &receive_sender,
                                        &send_error_sender,
                                        tls_receive_buffer_overflow(),
                                    );
                                    return;
                                }
                            }
                        }
                        // Re-chunk the bytes; each full chunk queued uses one unit of demand.
                        Ok(read) => {
                            match queue_tls_read_chunks(
                                &receive_sender,
                                &send_error_sender,
                                chunk_size,
                                &mut pending_tail,
                                &buffer[..read],
                            ) {
                                TlsReadQueueResult::Queued(queued) => {
                                    requested = requested.saturating_sub(queued);
                                }
                                TlsReadQueueResult::Closed => {
                                    read_open = false;
                                }
                                // The overflow was already reported by `queue_tls_read_chunks`.
                                TlsReadQueueResult::Failed => {
                                    return;
                                }
                            }
                        }
                        // Interrupted reads are simply retried on the next iteration.
                        Err(error) if error.kind() == std::io::ErrorKind::Interrupted => {}
                        Err(error) => {
                            report_tls_read_error(
                                &receive_sender,
                                &send_error_sender,
                                io_error(error),
                            );
                            return;
                        }
                    }
                }
            }
        } else {
            // Nothing to read for (read side closed or no demand): wait for a command.
            let Some(command) = commands.recv().await else {
                return;
            };
            if !handle_tls_carrier_command(
                &mut writer,
                command,
                &send_error_sender,
                &mut read_open,
                &mut write_open,
                &mut requested,
            )
            .await
            {
                return;
            }
        }
    }
}
615
616async fn handle_tls_carrier_command<W>(
617 writer: &mut W,
618 command: TlsCarrierCommand,
619 send_error_sender: &std_mpsc::Sender<StreamError>,
620 read_open: &mut bool,
621 write_open: &mut bool,
622 requested: &mut usize,
623) -> bool
624where
625 W: AsyncWrite + Unpin,
626{
627 match command {
628 TlsCarrierCommand::Demand(demand) => {
629 *requested = requested.saturating_add(demand);
630 true
631 }
632 TlsCarrierCommand::SendOne(chunk) => {
633 if !*write_open {
634 report_tls_write_error(
635 send_error_sender,
636 StreamError::Failed("TLS write side is closed".to_owned()),
637 );
638 return *read_open;
639 }
640 if write_one_tls_chunk(writer, send_error_sender, &chunk).await {
641 true
642 } else {
643 *write_open = false;
644 *read_open
645 }
646 }
647 TlsCarrierCommand::SendBatch(chunks) => {
648 if !*write_open {
649 report_tls_write_error(
650 send_error_sender,
651 StreamError::Failed("TLS write side is closed".to_owned()),
652 );
653 return *read_open;
654 }
655 for chunk in &chunks {
656 if let Err(error) = writer.write_all(chunk).await.map_err(io_error) {
657 report_tls_write_error(send_error_sender, error);
658 *write_open = false;
659 return *read_open;
660 }
661 }
662 if let Err(error) = writer.flush().await.map_err(io_error) {
663 report_tls_write_error(send_error_sender, error);
664 *write_open = false;
665 return *read_open;
666 }
667 true
668 }
669 TlsCarrierCommand::CloseRead => {
670 *read_open = false;
671 true
672 }
673 TlsCarrierCommand::CloseWrite { ack } => {
674 *write_open = false;
675 let result = close_tls_writer(writer).await;
676 match result {
677 Ok(()) => {
678 let _ = ack.send(Ok(()));
679 true
680 }
681 Err(error) => {
682 report_tls_write_error(send_error_sender, error.clone());
683 let _ = ack.send(Err(error));
684 *read_open
685 }
686 }
687 }
688 }
689}
690
691async fn write_one_tls_chunk<W>(
692 writer: &mut W,
693 send_error_sender: &std_mpsc::Sender<StreamError>,
694 chunk: &[u8],
695) -> bool
696where
697 W: AsyncWrite + Unpin,
698{
699 if let Err(error) = writer.write_all(chunk).await.map_err(io_error) {
700 report_tls_write_error(send_error_sender, error);
701 return false;
702 }
703 if let Err(error) = writer.flush().await.map_err(io_error) {
704 report_tls_write_error(send_error_sender, error);
705 return false;
706 }
707 true
708}
709
710async fn close_tls_writer<W>(writer: &mut W) -> StreamResult<()>
711where
712 W: AsyncWrite + Unpin,
713{
714 writer.flush().await.map_err(io_error)?;
715 writer.shutdown().await.map_err(io_error)
716}
717
/// Result of queuing the chunks produced by one read.
enum TlsReadQueueResult {
    // This many full chunks were queued.
    Queued(usize),
    // The read source dropped its receiver.
    Closed,
    // The receive channel overflowed; the error was already reported.
    Failed,
}

/// Result of a single non-blocking send into the receive channel.
enum TlsQueueOutcome {
    Queued,
    Full,
    Closed,
}
729
/// Re-chunks the bytes of one read into `chunk_size`-byte chunks and queues them.
///
/// Bytes left over from earlier reads (`pending_tail`) are completed first. Any remainder
/// shorter than a chunk is kept in `pending_tail` until later reads fill it or
/// `run_tls_carrier_task` flushes it at EOF.
///
/// NOTE(review): a peer that sends fewer than `chunk_size` bytes and then waits for a reply
/// never has those bytes delivered until more data or EOF arrives; confirm fixed-size
/// chunking is intended for request/response protocols.
///
/// Returns the number of full chunks queued, `Closed` if the read source is gone, or `Failed`
/// after reporting a receive-buffer overflow.
fn queue_tls_read_chunks(
    sender: &std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
    send_error_sender: &std_mpsc::Sender<StreamError>,
    chunk_size: usize,
    pending_tail: &mut Vec<u8>,
    read_buffer: &[u8],
) -> TlsReadQueueResult {
    let mut offset = 0;
    let mut queued = 0_usize;
    if !pending_tail.is_empty() {
        // Cannot underflow: a tail that reaches `chunk_size` is always sent immediately below.
        let needed = chunk_size - pending_tail.len();
        let take = needed.min(read_buffer.len());
        pending_tail.extend_from_slice(&read_buffer[..take]);
        offset += take;
        if pending_tail.len() == chunk_size {
            match try_send_tls_read_response(
                sender,
                DemandResponse::Item(std::mem::take(pending_tail)),
            ) {
                TlsQueueOutcome::Queued => queued += 1,
                TlsQueueOutcome::Closed => return TlsReadQueueResult::Closed,
                TlsQueueOutcome::Full => {
                    report_tls_read_error(sender, send_error_sender, tls_receive_buffer_overflow());
                    return TlsReadQueueResult::Failed;
                }
            }
        }
    }

    // Emit every whole chunk remaining in this read.
    while offset + chunk_size <= read_buffer.len() {
        let next = offset + chunk_size;
        match try_send_tls_read_response(
            sender,
            DemandResponse::Item(read_buffer[offset..next].to_vec()),
        ) {
            TlsQueueOutcome::Queued => queued += 1,
            TlsQueueOutcome::Closed => return TlsReadQueueResult::Closed,
            TlsQueueOutcome::Full => {
                report_tls_read_error(sender, send_error_sender, tls_receive_buffer_overflow());
                return TlsReadQueueResult::Failed;
            }
        }
        offset = next;
    }

    // Keep the short remainder for the next read.
    if offset < read_buffer.len() {
        pending_tail.extend_from_slice(&read_buffer[offset..]);
    }
    TlsReadQueueResult::Queued(queued)
}
780
781fn try_send_tls_read_response(
782 sender: &std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
783 item: DemandResponse<Vec<u8>>,
784) -> TlsQueueOutcome {
785 match sender.try_send(item) {
786 Ok(()) => TlsQueueOutcome::Queued,
787 Err(std_mpsc::TrySendError::Full(_)) => TlsQueueOutcome::Full,
788 Err(std_mpsc::TrySendError::Disconnected(_)) => TlsQueueOutcome::Closed,
789 }
790}
791
792fn report_tls_read_error(
793 receive_sender: &std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
794 send_error_sender: &std_mpsc::Sender<StreamError>,
795 error: StreamError,
796) {
797 let _ = send_error_sender.send(error.clone());
798 let _ = receive_sender.try_send(DemandResponse::Error(error));
799}
800
801fn report_tls_write_error(send_error_sender: &std_mpsc::Sender<StreamError>, error: StreamError) {
802 let _ = send_error_sender.send(error);
803}
804
805fn tls_receive_buffer_overflow() -> StreamError {
806 StreamError::Failed("TLS receive buffer filled without downstream demand".to_owned())
807}
808
/// Builds the write half of a connection: a `Sink` that can be materialized only once.
///
/// On materialization the carrier is taken out of the shared slot (a second materialization
/// fails with "TLS sink already materialized"). Each element is passed to `send_tls_chunk`,
/// and `close_tls_send_resource` flushes and half-closes the write side when the stage closes.
/// Elements are mapped to `NotUsed` and drained into `Sink::ignore`, whose completion is the
/// materialized value.
fn single_use_tls_sink_from_carrier(carrier: TlsCarrier, batch_size: usize) -> TlsByteSink {
    let carrier = Arc::new(Mutex::new(Some(carrier)));
    Flow::<Vec<u8>, Vec<u8>>::identity()
        .map_with_resource(
            {
                let carrier = Arc::clone(&carrier);
                move || {
                    let carrier = carrier
                        .lock()
                        .expect("single-use TLS carrier poisoned")
                        .take()
                        .ok_or_else(|| {
                            StreamError::Failed("TLS sink already materialized".into())
                        })?;
                    Ok(SendResource {
                        carrier,
                        pending: Vec::with_capacity(batch_size),
                        batch_size,
                    })
                }
            },
            |resource, chunk| {
                send_tls_chunk(resource, chunk)?;
                Ok(NotUsed)
            },
            close_tls_send_resource,
        )
        .to_mat(Sink::ignore(), Keep::right)
}
838
839fn close_tls_send_resource(mut resource: SendResource) -> StreamResult<Option<NotUsed>> {
840 flush_tls_send_resource(&mut resource)?;
841 resource.carrier.close_write()?;
842 Ok(None)
843}
844
845fn send_tls_chunk(resource: &mut SendResource, chunk: Vec<u8>) -> StreamResult<()> {
846 if resource.batch_size <= 1 {
847 return resource.carrier.send_one(chunk);
848 }
849 resource.pending.push(chunk);
850 if resource.pending.len() >= resource.batch_size {
851 flush_tls_send_resource(resource)?;
852 }
853 Ok(())
854}
855
856fn flush_tls_send_resource(resource: &mut SendResource) -> StreamResult<()> {
857 if resource.pending.is_empty() {
858 return resource.carrier.check_send_error();
859 }
860 let pending = std::mem::take(&mut resource.pending);
861 resource.carrier.send_items(pending)
862}
863
/// Wraps an already-bound listener in a `Source` of incoming TLS connections that can be
/// materialized only once.
///
/// The create step spawns `run_tls_bind_task` on `handle`. Each pull sends a one-shot reply
/// channel to that task and blocks until it answers. `close_bind_resource` cancels the task.
/// The materialized value is a [`TlsBinding`] carrying `local_addr`.
///
/// NOTE(review): pulls use `blocking_send` and a blocking `recv`; Tokio documents that
/// `blocking_send` panics inside an async context, so this presumably runs on a dedicated
/// stream thread.
fn tls_bind_source(
    listener: TcpListener,
    server_config: Arc<rustls::ServerConfig>,
    local_addr: SocketAddr,
    handle: Handle,
    chunk_size: usize,
) -> Source<TlsIncomingConnection, TlsBinding> {
    // Shared slot so the create closure can move the listener out exactly once.
    let listener = Arc::new(Mutex::new(Some(listener)));
    Source::unfold_resource(
        {
            let listener = Arc::clone(&listener);
            let handle = handle.clone();
            move || {
                let listener = listener
                    .lock()
                    .expect("single-use TLS listener poisoned")
                    .take()
                    .ok_or_else(|| {
                        StreamError::Failed("TLS listener already materialized".into())
                    })?;
                let (demand_sender, demand_receiver) = mpsc::channel(1);
                let (cancel_sender, cancel_receiver) = watch::channel(false);
                let task = handle.spawn(run_tls_bind_task(
                    listener,
                    Arc::clone(&server_config),
                    local_addr,
                    chunk_size,
                    handle.clone(),
                    demand_receiver,
                    cancel_receiver,
                ));
                Ok(BindResource {
                    demands: demand_sender,
                    cancel: cancel_sender,
                    task,
                })
            }
        },
        |resource| {
            // One demand = one reply channel; the bind task answers with exactly one response.
            let (reply_sender, reply_receiver) = std_mpsc::channel();
            resource
                .demands
                .blocking_send(reply_sender)
                .map_err(|_| abrupt_termination())?;
            match reply_receiver.recv() {
                Ok(DemandResponse::Item(connection)) => Ok(Some(connection)),
                Ok(DemandResponse::Complete) => Ok(None),
                Ok(DemandResponse::Error(error)) => Err(error),
                Err(_) => Err(abrupt_termination()),
            }
        },
        close_bind_resource,
    )
    .map_materialized_value(move |_| TlsBinding { local_addr })
}
919
920fn close_bind_resource(resource: BindResource) -> StreamResult<()> {
921 let _ = resource.cancel.send(true);
922 resource.task.abort();
923 Ok(())
924}
925
/// Bind task: serves exactly one accepted TLS connection per demand.
///
/// For each reply channel received on `demands`, it accepts a TCP connection (retrying errors
/// classified by `is_transient_accept_error`). It then runs the TLS handshake on a
/// per-connection execution and replies with the connection. Every wait also listens on
/// `cancel`; a change or a dropped sender ends the task. A non-transient accept error or a
/// failed handshake is sent as `DemandResponse::Error` and ends the task.
///
/// NOTE(review): the handshake runs inline, with no timeout, and its failure ends the task.
/// So one slow or misbehaving client can stall or terminate the whole listener. Consider a
/// handshake timeout and skipping failed handshakes.
async fn run_tls_bind_task(
    listener: TcpListener,
    server_config: Arc<rustls::ServerConfig>,
    local_addr: SocketAddr,
    chunk_size: usize,
    handle: Handle,
    mut demands: mpsc::Receiver<std_mpsc::Sender<DemandResponse<TlsIncomingConnection>>>,
    mut cancel: watch::Receiver<bool>,
) {
    let acceptor = TlsAcceptor::from(server_config);
    loop {
        // Wait until downstream asks for a connection.
        let reply = tokio::select! {
            demand = demands.recv() => match demand {
                Some(reply) => reply,
                None => return,
            },
            changed = cancel.changed() => {
                // Either a cancel signal or the sender was dropped: stop in both cases.
                let _ = changed;
                return;
            }
        };

        let (tcp, remote_addr) = loop {
            let accepted = tokio::select! {
                accepted = listener.accept() => accepted,
                changed = cancel.changed() => {
                    let _ = changed;
                    return;
                }
            };

            match accepted {
                Ok(accepted) => break accepted,
                Err(error) if is_transient_accept_error(&error) => continue,
                Err(error) => {
                    let _ = reply.send(DemandResponse::Error(io_error(error)));
                    return;
                }
            }
        };

        let connection = TlsConnection {
            // Fall back to the listener's address if the socket cannot report its own.
            local_addr: tcp.local_addr().unwrap_or(local_addr),
            remote_addr,
        };
        let execution = tls_connection_execution(handle.clone());
        let accepted = tokio::select! {
            accepted = accept_tls_on_execution(tcp, acceptor.clone(), &execution) => accepted,
            changed = cancel.changed() => {
                let _ = changed;
                return;
            }
        };

        match accepted {
            Ok(stream) => {
                let incoming = tls_incoming_connection(stream, connection, execution, chunk_size);
                // The requester is gone (its reply receiver was dropped); stop serving.
                if reply.send(DemandResponse::Item(incoming)).is_err() {
                    return;
                }
            }
            Err(error) => {
                let _ = reply.send(DemandResponse::Error(error));
                return;
            }
        }
    }
}
994
/// Returns `true` for `accept` errors that only concern the one connection being accepted, so
/// the listener should just try again instead of failing.
fn is_transient_accept_error(error: &std::io::Error) -> bool {
    matches!(
        error.kind(),
        std::io::ErrorKind::Interrupted
            | std::io::ErrorKind::ConnectionAborted
            | std::io::ErrorKind::ConnectionReset
    ) || error.raw_os_error().is_some_and(is_transient_accept_errno)
}

/// Linux `accept(2)` errno values to retry rather than treat as listener failures.
///
/// Besides `EINTR`, `ECONNABORTED` and `ECONNRESET`, Linux reports network errors already
/// pending on the new socket as errors from `accept` itself. The accept(2) man page says to
/// treat the TCP/IP ones (`ENETDOWN`, `EPROTO`, `ENOPROTOOPT`, `EHOSTDOWN`, `ENONET`,
/// `EHOSTUNREACH`, `EOPNOTSUPP`, `ENETUNREACH`) like `EAGAIN` and retry.
///
/// The numbers below are the asm-generic errno values used by most Linux architectures. MIPS
/// and SPARC number errno differently, so there this returns `false`, leaving only the portable
/// `ErrorKind` checks.
#[cfg(target_os = "linux")]
fn is_transient_accept_errno(code: i32) -> bool {
    if cfg!(any(
        target_arch = "mips",
        target_arch = "mips64",
        target_arch = "mips32r6",
        target_arch = "mips64r6",
        target_arch = "sparc",
        target_arch = "sparc64"
    )) {
        return false;
    }
    matches!(
        code,
        4 // EINTR
            | 64 // ENONET
            | 71 // EPROTO
            | 92 // ENOPROTOOPT
            | 95 // EOPNOTSUPP
            | 100 // ENETDOWN
            | 101 // ENETUNREACH
            | 103 // ECONNABORTED
            | 104 // ECONNRESET
            | 112 // EHOSTDOWN
            | 113 // EHOSTUNREACH
    )
}

/// Other platforms rely solely on the portable `ErrorKind` checks.
#[cfg(not(target_os = "linux"))]
fn is_transient_accept_errno(_code: i32) -> bool {
    false
}
1013
/// Chooses the execution a new TLS connection's work runs on.
///
/// Delegates to `async_carrier::sharded_tokio_carrier_execution` with the process-wide
/// `ACTIVE_TLS_CONNECTIONS` counter. It presumably picks a shard by load and falls back to
/// `fallback` (confirm in `async_carrier`).
pub(crate) fn tls_connection_execution(
    fallback: Handle,
) -> async_carrier::ShardedTokioCarrierExecution {
    async_carrier::sharded_tokio_carrier_execution(fallback, &ACTIVE_TLS_CONNECTIONS)
}
1019
1020pub(crate) async fn tls_client_connect<A>(
1021 addr: A,
1022 server_name: ServerName<'static>,
1023 client_config: Arc<rustls::ClientConfig>,
1024 fallback: Handle,
1025 chunk_size: usize,
1026) -> StreamResult<Flow<Vec<u8>, Vec<u8>, TlsConnection>>
1027where
1028 A: ToSocketAddrs + Send + 'static,
1029{
1030 let execution = tls_connection_execution(fallback);
1031 let (tls, connection) = execution
1032 .run(async move {
1033 let tcp = TcpStream::connect(addr).await.map_err(io_error)?;
1034 let connection = TlsConnection {
1035 local_addr: tcp.local_addr().map_err(io_error)?,
1036 remote_addr: tcp.peer_addr().map_err(io_error)?,
1037 };
1038 let tls = TlsConnector::from(client_config)
1039 .connect(server_name, tcp)
1040 .await
1041 .map_err(io_error)?;
1042 Ok((tls, connection))
1043 })
1044 .await?;
1045 Ok(tls_flow_from_stream_with_execution(
1046 tls, connection, execution, chunk_size,
1047 ))
1048}
1049
1050async fn accept_tls_on_execution(
1051 tcp: TcpStream,
1052 acceptor: TlsAcceptor,
1053 execution: &async_carrier::ShardedTokioCarrierExecution,
1054) -> StreamResult<tokio_rustls::server::TlsStream<TcpStream>> {
1055 enum AcceptedTcp {
1056 Tokio(TcpStream),
1057 Std(std::net::TcpStream),
1058 }
1059
1060 let tcp = if execution.is_sharded() {
1061 AcceptedTcp::Std(tcp.into_std().map_err(io_error)?)
1062 } else {
1063 AcceptedTcp::Tokio(tcp)
1064 };
1065 execution
1066 .run(async move {
1067 let tcp = match tcp {
1068 AcceptedTcp::Std(std_tcp) => TcpStream::from_std(std_tcp).map_err(io_error)?,
1069 AcceptedTcp::Tokio(tcp) => tcp,
1070 };
1071 acceptor.accept(tcp).await.map_err(io_error)
1072 })
1073 .await
1074}