ombrac_server/
server.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
use std::io;
use std::net::SocketAddr;

use ombrac::io::Streamable;
use ombrac::request::{Address, Request};
use ombrac::Provider;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;

use ombrac_macros::error;

/// Proxy server built on top of a transport that hands out inbound streams.
///
/// The transport is only polled from [`Server::listen`]; constructing a
/// `Server` does not accept any connections by itself.
pub struct Server<Transport> {
    /// Source of inbound streams; `listen` calls `fetch` on it in a loop.
    transport: Transport,
}

impl<Transport, Stream> Server<Transport>
where
    Transport: Provider<Item = Stream>,
    Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    pub fn new(transport: Transport) -> Self {
        Self { transport }
    }

    async fn handler(mut stream: Stream) -> io::Result<()> {
        let request = Request::read(&mut stream).await?;

        match request {
            Request::TcpConnect(address) => {
                let address = Self::resolve(address).await?;
                let mut outbound = TcpStream::connect(address).await?;

                ombrac::io::util::copy_bidirectional(&mut stream, &mut outbound).await?
            }
        };

        Ok(())
    }

    async fn resolve(address: Address) -> io::Result<SocketAddr> {
        use crate::dns::lookup_ip;

        let result = match address {
            Address::Domain(domain, port) => SocketAddr::new(lookup_ip(&domain).await?, port),
            Address::IPv4(addr) => SocketAddr::V4(addr),
            Address::IPv6(addr) => SocketAddr::V6(addr),
        };

        Ok(result)
    }

    pub async fn listen(&mut self) -> io::Result<()> {
        while let Some(stream) = self.transport.fetch().await {
            tokio::spawn(async move {
                if let Err(_error) = Self::handler(stream).await {
                    error!("{}", _error);
                }
            });
        }

        Ok(())
    }
}