1use anyhow::Result;
2use easy_tokio_rustls::{TlsListener, TlsServer, TlsStream};
3use tokio::{
4 io::{AsyncReadExt, AsyncWriteExt},
5 net::{TcpListener, TcpStream, UnixListener, UnixStream},
6};
7
8use crate::DEFAULT_BUFFER_SIZE;
9
10pub struct Server {
12 pub address: String,
14 pub listener: StreamListener,
17}
18
19impl Server {
20 pub async fn listen<T>(address: T, cert_and_key: Option<CertAndKeyFilePaths>) -> Result<Server>
25 where
26 T: ToString,
27 {
28 use StreamListener::*;
29 let address = address.to_string();
30 let listener = match &address {
31 path if path.starts_with("unix://") => Unix(UnixListener::bind(&path[7..])?),
32 id_like_to_think_tls if cert_and_key.is_some() => {
34 match id_like_to_think_tls.contains("://") {
35 true => {
36 let addr = id_like_to_think_tls.split_once("://").unwrap().1;
37 let cert_file = &cert_and_key.as_ref().unwrap().cert;
38 let key_file = &cert_and_key.as_ref().unwrap().key;
39 let server = TlsServer::new(addr, cert_file, key_file).await?;
40 Tls(server.listen().await?)
41 }
42 false => {
43 let cert_file = &cert_and_key.as_ref().unwrap().cert;
44 let key_file = &cert_and_key.as_ref().unwrap().key;
45 let server =
46 TlsServer::new(id_like_to_think_tls, cert_file, key_file).await?;
47 Tls(server.listen().await?)
48 }
49 }
50 }
51 fine_assumed_tcp => match fine_assumed_tcp.contains("://") {
52 true => {
53 let addr = fine_assumed_tcp.split_once("://").unwrap().1;
54 Tcp(TcpListener::bind(addr).await?)
55 }
56 _ => Tcp(TcpListener::bind(fine_assumed_tcp).await?),
57 },
58 };
59 let server = Server { address, listener };
60 Ok(server)
61 }
62
63 pub async fn accept(&mut self) -> Result<StreamClient> {
65 let (stream, address) = match &self.listener {
66 StreamListener::Tcp(listener) => {
67 let (stream, address) = listener.accept().await?;
68 (ClientStream::Tcp(stream), address.to_string())
69 }
70 StreamListener::Tls(listener) => {
71 let (tcp_stream, address) = listener.stream_accept().await?;
72 let stream = tcp_stream.tls_accept().await?;
73 (ClientStream::Tls(Box::new(stream)), address.to_string())
74 }
75 StreamListener::Unix(listener) => {
76 let (stream, _) = listener.accept().await?;
77 (ClientStream::Unix(stream), self.address.clone())
78 }
79 };
80 Ok(StreamClient { address, stream })
81 }
82}
83
84pub enum StreamListener {
86 Tcp(TcpListener),
87 Tls(TlsListener),
88 Unix(UnixListener),
89}
90
91pub struct StreamClient {
93 pub address: String,
94 pub stream: ClientStream,
95}
96
97impl StreamClient {
98 pub async fn send(&mut self, data: &[u8]) -> Result<()> {
100 use ClientStream::*;
101 match &mut self.stream {
102 Tcp(stream) => {
103 stream.write_all(data).await?;
104 }
105 Tls(stream) => {
106 stream.write_all(data).await?;
107 }
108 Unix(stream) => {
109 stream.write_all(data).await?;
110 }
111 };
112 Ok(())
113 }
114
115 pub async fn recv(&mut self) -> Result<Vec<u8>> {
117 use ClientStream::*;
118
119 let mut buffer = [0; DEFAULT_BUFFER_SIZE];
120 let size = match &mut self.stream {
121 Tcp(stream) => stream.read(&mut buffer).await?,
122 Tls(stream) => stream.read(&mut buffer).await?,
123 Unix(stream) => stream.read(&mut buffer).await?,
124 };
125 Ok(buffer[0..size].to_vec())
126 }
127}
128
129pub enum ClientStream {
131 Tcp(TcpStream),
133 Tls(Box<TlsStream<TcpStream>>),
135 Unix(UnixStream),
137}
138
/// Paths to the certificate file and private-key file used to enable TLS.
///
/// Passing `Some(CertAndKeyFilePaths { .. })` to `Server::listen` makes any
/// non-`unix://` address listen over TLS.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CertAndKeyFilePaths {
    /// Path to the certificate file.
    pub cert: String,
    /// Path to the private-key file.
    pub key: String,
}

impl CertAndKeyFilePaths {
    /// Builds a `CertAndKeyFilePaths` from any two values that implement
    /// [`ToString`], such as `&str`, `String`, or `path.display()`.
    pub fn new<T, U>(cert: T, key: U) -> Self
    where
        T: ToString,
        U: ToString,
    {
        Self {
            cert: cert.to_string(),
            key: key.to_string(),
        }
    }
}