// ombrac_server/server.rs

1use std::io;
2
3use ombrac::io::Streamable;
4use ombrac::request::{Address, Request};
5use ombrac::Provider;
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio::net::TcpStream;
8
9use ombrac_macros::error;
10
/// A server that accepts streams from a transport and serves only those
/// requests that present the configured 32-byte secret.
pub struct Server<P> {
    /// Shared secret that incoming requests are checked against.
    secret: [u8; 32],
    /// Source of incoming streams.
    transport: P,
}
15
16impl<Transport, Stream> Server<Transport>
17where
18    Transport: Provider<Item = Stream>,
19    Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
20{
21    pub fn new(secret: [u8; 32], transport: Transport) -> Self {
22        Self { secret, transport }
23    }
24
25    async fn handler(mut stream: Stream, secret: &[u8; 32]) -> io::Result<()> {
26        let request = Request::read(&mut stream).await?;
27
28        match request {
29            Request::TcpConnect(client_auth, addr) => {
30                if &client_auth != secret {
31                    return Err(io::Error::new(
32                        io::ErrorKind::PermissionDenied,
33                        "Authentication failed",
34                    ));
35                }
36                Self::handle_tcp_connect(stream, addr).await?
37            }
38        };
39
40        Ok(())
41    }
42
43    async fn handle_tcp_connect<A>(mut stream: Stream, addr: A) -> io::Result<Stream>
44    where
45        A: Into<Address>,
46    {
47        let addr = addr.into().to_socket_addr().await?;
48        let mut outbound = TcpStream::connect(addr).await?;
49
50        ombrac::io::util::copy_bidirectional(&mut stream, &mut outbound).await?;
51
52        Ok(stream)
53    }
54
55    pub async fn listen(&self) -> io::Result<()> {
56        let secret = self.secret.clone();
57
58        while let Some(stream) = self.transport.fetch().await {
59            tokio::spawn(async move {
60                if let Err(e) = Self::handler(stream, &secret).await {
61                    error!("{}", e);
62                }
63            });
64        }
65
66        Ok(())
67    }
68}