1use std::{fmt::Debug, sync::Arc};
2
3use async_trait::async_trait;
4use bytes::{Bytes, BytesMut};
5use memchr::memmem::Finder;
6use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
7use tokio::{
8 io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
9 net::{TcpStream, ToSocketAddrs},
10};
11use tokio_rustls::{
12 client::TlsStream,
13 rustls::{ClientConfig, RootCertStore},
14 TlsConnector,
15};
16
17use crate::{message::MARKER, Error};
18
19use super::{RecvHandle, SendHandle, Transport};
20
21#[derive(Debug)]
22pub struct Tls {
23 stream: TlsStream<TcpStream>,
24}
25
26impl Tls {
27 #[tracing::instrument(skip_all, level = "debug")]
28 pub(crate) async fn connect<A, S>(
29 addr: A,
30 server_name: S,
31 ca_cert: CertificateDer<'_>,
32 client_cert: CertificateDer<'static>,
33 client_key: PrivateKeyDer<'static>,
34 ) -> Result<Self, Error>
35 where
36 A: ToSocketAddrs + Debug + Send,
37 S: TryInto<ServerName<'static>> + Debug + Send,
38 Error: From<S::Error>,
39 {
40 let root_store = {
41 let mut store = RootCertStore::empty();
42 store.add(ca_cert)?;
43 store
44 };
45 let config = Arc::new(
46 ClientConfig::builder()
47 .with_root_certificates(root_store)
48 .with_client_auth_cert(vec![client_cert], client_key)?,
49 );
50 tracing::debug!(?config);
51 let domain = server_name.try_into()?;
52 let tcp_stream = TcpStream::connect(addr).await?;
53 let stream = TlsConnector::from(config)
54 .connect(domain, tcp_stream)
55 .await?;
56 Ok(Self { stream })
57 }
58}
59
60impl Transport for Tls {
61 type SendHandle = Sender;
62 type RecvHandle = Receiver;
63
64 #[tracing::instrument(level = "debug")]
65 fn split(self) -> (Self::SendHandle, Self::RecvHandle) {
66 let (read, write) = tokio::io::split(self.stream);
67 (Sender::new(write), Receiver::new(read))
68 }
69}
70
71#[derive(Debug)]
72pub struct Sender {
73 write: WriteHalf<TlsStream<TcpStream>>,
74}
75
76impl Sender {
77 const fn new(write: WriteHalf<TlsStream<TcpStream>>) -> Self {
78 Self { write }
79 }
80}
81
82#[async_trait]
83impl SendHandle for Sender {
84 #[tracing::instrument(level = "debug")]
85 async fn send(&mut self, data: Bytes) -> Result<(), Error> {
86 self.write.write_all(&data).await?;
87 self.write.flush().await?;
88 Ok(())
89 }
90}
91
92#[derive(Debug)]
93pub struct Receiver {
94 read: ReadHalf<TlsStream<TcpStream>>,
95 buf: BytesMut,
96 finder: Finder<'static>,
97}
98
99impl Receiver {
100 fn new(read: ReadHalf<TlsStream<TcpStream>>) -> Self {
101 let buf = BytesMut::with_capacity(1 << 10);
102 let finder = Finder::new(MARKER);
103 Self { read, buf, finder }
104 }
105}
106
107#[async_trait]
108impl RecvHandle for Receiver {
109 #[tracing::instrument(skip(self), level = "debug")]
110 async fn recv(&mut self) -> Result<Bytes, Error> {
111 let mut searched = 0;
114 loop {
115 tracing::trace!(?self.buf, "searching for message-break marker");
116 if let Some(index) = self.finder.find(&self.buf[searched..]) {
117 let end = searched + index + MARKER.len();
118 tracing::debug!("splitting {end} bytes from read buffer");
119 let message = self.buf.split_to(end).freeze();
120 tracing::trace!(?message);
121 break Ok(message);
122 }
123 searched = self.buf.len();
124 tracing::trace!("trying to read from transport");
125 let len = self.read.read_buf(&mut self.buf).await?;
126 tracing::trace!("read {len} bytes. buffer length is {}", self.buf.len());
127 }
128 }
129}