multistream/
server.rs

1use anyhow::Result;
2use easy_tokio_rustls::{TlsListener, TlsServer, TlsStream};
3use tokio::{
4    io::{AsyncReadExt, AsyncWriteExt},
5    net::{TcpListener, TcpStream, UnixListener, UnixStream},
6};
7
8use crate::DEFAULT_BUFFER_SIZE;
9
10/// Represents a server! Get ready for all those new connections
11pub struct Server {
12    /// The address being listened to
13    pub address: String,
14    /// Provides access to the underlying stream, so you can work beyond the
15    /// simple abstraction
16    pub listener: StreamListener,
17}
18
19impl Server {
20    /// Create a new server, will be listening and ready to go!
21    /// If a `cert_and_key` is provided, TLS will be used
22    /// If address begins with unix://, Unix socket will be used
23    /// Otherwise, TCP socket type will be assumed
24    pub async fn listen<T>(address: T, cert_and_key: Option<CertAndKeyFilePaths>) -> Result<Server>
25    where
26        T: ToString,
27    {
28        use StreamListener::*;
29        let address = address.to_string();
30        let listener = match &address {
31            path if path.starts_with("unix://") => Unix(UnixListener::bind(&path[7..])?),
32            // A nice default
33            id_like_to_think_tls if cert_and_key.is_some() => {
34                match id_like_to_think_tls.contains("://") {
35                    true => {
36                        let addr = id_like_to_think_tls.split_once("://").unwrap().1;
37                        let cert_file = &cert_and_key.as_ref().unwrap().cert;
38                        let key_file = &cert_and_key.as_ref().unwrap().key;
39                        let server = TlsServer::new(addr, cert_file, key_file).await?;
40                        Tls(server.listen().await?)
41                    }
42                    false => {
43                        let cert_file = &cert_and_key.as_ref().unwrap().cert;
44                        let key_file = &cert_and_key.as_ref().unwrap().key;
45                        let server =
46                            TlsServer::new(id_like_to_think_tls, cert_file, key_file).await?;
47                        Tls(server.listen().await?)
48                    }
49                }
50            }
51            fine_assumed_tcp => match fine_assumed_tcp.contains("://") {
52                true => {
53                    let addr = fine_assumed_tcp.split_once("://").unwrap().1;
54                    Tcp(TcpListener::bind(addr).await?)
55                }
56                _ => Tcp(TcpListener::bind(fine_assumed_tcp).await?),
57            },
58        };
59        let server = Server { address, listener };
60        Ok(server)
61    }
62
63    /// Accept connection from a new client
64    pub async fn accept(&mut self) -> Result<StreamClient> {
65        let (stream, address) = match &self.listener {
66            StreamListener::Tcp(listener) => {
67                let (stream, address) = listener.accept().await?;
68                (ClientStream::Tcp(stream), address.to_string())
69            }
70            StreamListener::Tls(listener) => {
71                let (tcp_stream, address) = listener.stream_accept().await?;
72                let stream = tcp_stream.tls_accept().await?;
73                (ClientStream::Tls(Box::new(stream)), address.to_string())
74            }
75            StreamListener::Unix(listener) => {
76                let (stream, _) = listener.accept().await?;
77                (ClientStream::Unix(stream), self.address.clone())
78            }
79        };
80        Ok(StreamClient { address, stream })
81    }
82}
83
84/// This is the underlying listener handle for the server
85pub enum StreamListener {
86    Tcp(TcpListener),
87    Tls(TlsListener),
88    Unix(UnixListener),
89}
90
91/// This structure represents a connected client
92pub struct StreamClient {
93    pub address: String,
94    pub stream: ClientStream,
95}
96
97impl StreamClient {
98    /// Sends the provided buffer to the connected client
99    pub async fn send(&mut self, data: &[u8]) -> Result<()> {
100        use ClientStream::*;
101        match &mut self.stream {
102            Tcp(stream) => {
103                stream.write_all(data).await?;
104            }
105            Tls(stream) => {
106                stream.write_all(data).await?;
107            }
108            Unix(stream) => {
109                stream.write_all(data).await?;
110            }
111        };
112        Ok(())
113    }
114
115    /// Receives data from the connected client
116    pub async fn recv(&mut self) -> Result<Vec<u8>> {
117        use ClientStream::*;
118
119        let mut buffer = [0; DEFAULT_BUFFER_SIZE];
120        let size = match &mut self.stream {
121            Tcp(stream) => stream.read(&mut buffer).await?,
122            Tls(stream) => stream.read(&mut buffer).await?,
123            Unix(stream) => stream.read(&mut buffer).await?,
124        };
125        Ok(buffer[0..size].to_vec())
126    }
127}
128
129/// Holds the underlying stream handle for a connected client
130pub enum ClientStream {
131    /// Handle for a TCP connected client
132    Tcp(TcpStream),
133    /// Handle for a TLS connected client
134    Tls(Box<TlsStream<TcpStream>>),
135    /// Handle for a Unix socket connected client
136    Unix(UnixStream),
137}
138
/// Simple structure that holds the file paths to a TLS certificate and its key.
///
/// Only the paths are stored here; this type does not read the files itself.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CertAndKeyFilePaths {
    /// Path to the TLS certificate file
    pub cert: String,
    /// Path to the TLS key file
    pub key: String,
}
146
147impl CertAndKeyFilePaths {
148    /// Creates a new Cert/Key file path pair
149    pub fn new<T, U>(cert: T, key: U) -> Self
150    where
151        T: ToString,
152        U: ToString,
153    {
154        CertAndKeyFilePaths {
155            cert: cert.to_string(),
156            key: key.to_string(),
157        }
158    }
159}