netconf/transport/
tls.rs

1use std::{fmt::Debug, sync::Arc};
2
3use async_trait::async_trait;
4use bytes::{Bytes, BytesMut};
5use memchr::memmem::Finder;
6use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
7use tokio::{
8    io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
9    net::{TcpStream, ToSocketAddrs},
10};
11use tokio_rustls::{
12    client::TlsStream,
13    rustls::{ClientConfig, RootCertStore},
14    TlsConnector,
15};
16
17use crate::{message::MARKER, Error};
18
19use super::{RecvHandle, SendHandle, Transport};
20
21#[derive(Debug)]
22pub struct Tls {
23    stream: TlsStream<TcpStream>,
24}
25
26impl Tls {
27    #[tracing::instrument(skip_all, level = "debug")]
28    pub(crate) async fn connect<A, S>(
29        addr: A,
30        server_name: S,
31        ca_cert: CertificateDer<'_>,
32        client_cert: CertificateDer<'static>,
33        client_key: PrivateKeyDer<'static>,
34    ) -> Result<Self, Error>
35    where
36        A: ToSocketAddrs + Debug + Send,
37        S: TryInto<ServerName<'static>> + Debug + Send,
38        Error: From<S::Error>,
39    {
40        let root_store = {
41            let mut store = RootCertStore::empty();
42            store.add(ca_cert)?;
43            store
44        };
45        let config = Arc::new(
46            ClientConfig::builder()
47                .with_root_certificates(root_store)
48                .with_client_auth_cert(vec![client_cert], client_key)?,
49        );
50        tracing::debug!(?config);
51        let domain = server_name.try_into()?;
52        let tcp_stream = TcpStream::connect(addr).await?;
53        let stream = TlsConnector::from(config)
54            .connect(domain, tcp_stream)
55            .await?;
56        Ok(Self { stream })
57    }
58}
59
60impl Transport for Tls {
61    type SendHandle = Sender;
62    type RecvHandle = Receiver;
63
64    #[tracing::instrument(level = "debug")]
65    fn split(self) -> (Self::SendHandle, Self::RecvHandle) {
66        let (read, write) = tokio::io::split(self.stream);
67        (Sender::new(write), Receiver::new(read))
68    }
69}
70
71#[derive(Debug)]
72pub struct Sender {
73    write: WriteHalf<TlsStream<TcpStream>>,
74}
75
76impl Sender {
77    const fn new(write: WriteHalf<TlsStream<TcpStream>>) -> Self {
78        Self { write }
79    }
80}
81
82#[async_trait]
83impl SendHandle for Sender {
84    #[tracing::instrument(level = "debug")]
85    async fn send(&mut self, data: Bytes) -> Result<(), Error> {
86        self.write.write_all(&data).await?;
87        self.write.flush().await?;
88        Ok(())
89    }
90}
91
92#[derive(Debug)]
93pub struct Receiver {
94    read: ReadHalf<TlsStream<TcpStream>>,
95    buf: BytesMut,
96    finder: Finder<'static>,
97}
98
99impl Receiver {
100    fn new(read: ReadHalf<TlsStream<TcpStream>>) -> Self {
101        let buf = BytesMut::with_capacity(1 << 10);
102        let finder = Finder::new(MARKER);
103        Self { read, buf, finder }
104    }
105}
106
107#[async_trait]
108impl RecvHandle for Receiver {
109    #[tracing::instrument(skip(self), level = "debug")]
110    async fn recv(&mut self) -> Result<Bytes, Error> {
111        // TODO:
112        // handle case when read ends part way through an end marker
113        let mut searched = 0;
114        loop {
115            tracing::trace!(?self.buf, "searching for message-break marker");
116            if let Some(index) = self.finder.find(&self.buf[searched..]) {
117                let end = searched + index + MARKER.len();
118                tracing::debug!("splitting {end} bytes from read buffer");
119                let message = self.buf.split_to(end).freeze();
120                tracing::trace!(?message);
121                break Ok(message);
122            }
123            searched = self.buf.len();
124            tracing::trace!("trying to read from transport");
125            let len = self.read.read_buf(&mut self.buf).await?;
126            tracing::trace!("read {len} bytes. buffer length is {}", self.buf.len());
127        }
128    }
129}