ombrac_server/
server.rs

1use std::{io, sync::Arc};
2
3use ombrac::prelude::*;
4use ombrac_transport::{Reliable, Transport};
5
6#[cfg(feature = "datagram")]
7use ombrac_transport::Unreliable;
8
9use ombrac_macros::error;
10
11pub struct Server<T> {
12    secret: Secret,
13    transport: T,
14}
15
16impl<T: Transport> Server<T> {
17    pub fn new(secret: Secret, transport: T) -> Self {
18        Self { secret, transport }
19    }
20
21    async fn handle_reliable(stream: impl Reliable, secret: Secret) -> io::Result<()> {
22        Self::handle_tcp_connect(stream, secret).await
23    }
24
25    #[cfg(feature = "datagram")]
26    async fn handle_unreliable(stream: impl Unreliable, secret: Secret) -> io::Result<()> {
27        Self::handle_udp_associate(stream, secret).await
28    }
29
30    #[inline]
31    async fn handle_tcp_connect(mut stream: impl Reliable, secret: Secret) -> io::Result<()> {
32        use tokio::net::TcpStream;
33
34        let request = Connect::from_async_read(&mut stream).await?;
35
36        if request.secret != secret {
37            return Err(io::Error::new(
38                io::ErrorKind::PermissionDenied,
39                "Secret does not match",
40            ));
41        }
42
43        let addr = request.address.to_socket_addr().await?;
44        let mut target = TcpStream::connect(addr).await?;
45
46        ombrac::io::util::copy_bidirectional(&mut stream, &mut target).await?;
47
48        Ok(())
49    }
50
51    #[cfg(feature = "datagram")]
52    #[inline]
53    async fn handle_udp_associate(conn: impl Unreliable, secret: Secret) -> io::Result<()> {
54        use std::net::SocketAddr;
55        use tokio::net::UdpSocket;
56
57        const DEFAULT_BUFFER_SIZE: usize = 2 * 1024;
58
59        let local = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 0));
60        let socket = UdpSocket::bind(local).await?;
61
62        let sock_send = Arc::new(socket);
63        let sock_recv = Arc::clone(&sock_send);
64        let conn_send = Arc::new(conn);
65        let conn_recv = Arc::clone(&conn_send);
66
67        let handle = tokio::spawn(async move {
68            let mut buf = [0u8; DEFAULT_BUFFER_SIZE];
69
70            loop {
71                let (len, addr) = sock_recv.recv_from(&mut buf).await?;
72
73                let data = bytes::Bytes::copy_from_slice(&buf[..len]);
74                let packet = Packet::with(secret, addr, data);
75
76                if conn_send.send(packet.to_bytes()?).await.is_err() {
77                    break;
78                }
79            }
80
81            Ok::<(), io::Error>(())
82        });
83
84        while let Ok(mut packet) = conn_recv.recv().await {
85            let packet = Packet::from_bytes(&mut packet)?;
86
87            if packet.secret != secret {
88                return Err(io::Error::new(
89                    io::ErrorKind::PermissionDenied,
90                    "Secret does not match",
91                ));
92            };
93
94            let target = packet.address.to_socket_addr().await?;
95            sock_send.send_to(&packet.data, target).await?;
96        }
97
98        handle.abort();
99
100        Ok(())
101    }
102
103    pub async fn listen(self) -> io::Result<()> {
104        let secret = self.secret.clone();
105
106        let transport = Arc::new(self.transport);
107
108        #[cfg(feature = "datagram")]
109        {
110            let unreliable_transport = transport.clone();
111            tokio::spawn(async move {
112                match unreliable_transport.unreliable().await {
113                    Ok(stream) => {
114                        tokio::spawn(async move {
115                            if let Err(_error) = Self::handle_unreliable(stream, secret).await {
116                                error!("{_error}");
117                            }
118                        });
119                    }
120                    Err(_error) => {
121                        error!("{}", _error);
122                    }
123                };
124
125                ()
126            });
127        }
128
129        loop {
130            match transport.reliable().await {
131                Ok(stream) => tokio::spawn(async move {
132                    if let Err(_error) = Self::handle_reliable(stream, secret).await {
133                        error!("{_error}");
134                    }
135                }),
136                Err(err) => return Err(io::Error::other(err.to_string())),
137            };
138        }
139    }
140}