1use std::io;
2
3use ombrac::io::Streamable;
4use ombrac::request::{Address, Request};
5use ombrac::Provider;
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio::net::TcpStream;
8
9use ombrac_macros::error;
10
/// Proxy server state.
///
/// Holds the 32-byte shared secret that clients must present, together with
/// the transport `P` that supplies inbound streams to the accept loop.
pub struct Server<P> {
    // Shared credential compared against each client's request.
    secret: [u8; 32],
    // Source of inbound connections.
    transport: P,
}
15
16impl<Transport, Stream> Server<Transport>
17where
18 Transport: Provider<Item = Stream>,
19 Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
20{
21 pub fn new(secret: [u8; 32], transport: Transport) -> Self {
22 Self { secret, transport }
23 }
24
25 async fn handler(mut stream: Stream, secret: &[u8; 32]) -> io::Result<()> {
26 let request = Request::read(&mut stream).await?;
27
28 match request {
29 Request::TcpConnect(client_auth, addr) => {
30 if &client_auth != secret {
31 return Err(io::Error::new(
32 io::ErrorKind::PermissionDenied,
33 "Authentication failed",
34 ));
35 }
36 Self::handle_tcp_connect(stream, addr).await?
37 }
38 };
39
40 Ok(())
41 }
42
43 async fn handle_tcp_connect<A>(mut stream: Stream, addr: A) -> io::Result<Stream>
44 where
45 A: Into<Address>,
46 {
47 let addr = addr.into().to_socket_addr().await?;
48 let mut outbound = TcpStream::connect(addr).await?;
49
50 ombrac::io::util::copy_bidirectional(&mut stream, &mut outbound).await?;
51
52 Ok(stream)
53 }
54
55 pub async fn listen(&self) -> io::Result<()> {
56 let secret = self.secret.clone();
57
58 while let Some(stream) = self.transport.fetch().await {
59 tokio::spawn(async move {
60 if let Err(e) = Self::handler(stream, &secret).await {
61 error!("{}", e);
62 }
63 });
64 }
65
66 Ok(())
67 }
68}