// ombrac_server/server.rs

1use std::{io, sync::Arc};
2
3use ombrac::prelude::*;
4use ombrac_transport::{Acceptor, Reliable};
5
6#[cfg(feature = "datagram")]
7use ombrac_transport::Unreliable;
8
9use ombrac_macros::error;
10
11pub struct Server<T> {
12    secret: Secret,
13    transport: T,
14}
15
16impl<T: Acceptor> Server<T> {
17    pub fn new(secret: Secret, transport: T) -> Self {
18        Self { secret, transport }
19    }
20
21    async fn handle_reliable(stream: impl Reliable, secret: Secret) -> io::Result<()> {
22        Self::handle_tcp_connect(stream, secret).await
23    }
24
25    #[cfg(feature = "datagram")]
26    async fn handle_unreliable(stream: impl Unreliable, secret: Secret) -> io::Result<()> {
27        Self::handle_udp_associate(stream, secret).await
28    }
29
30    #[inline]
31    async fn handle_tcp_connect(mut stream: impl Reliable, secret: Secret) -> io::Result<()> {
32        use tokio::net::TcpStream;
33
34        let request = Connect::from_async_read(&mut stream).await?;
35
36        if request.secret != secret {
37            return Err(io::Error::new(
38                io::ErrorKind::PermissionDenied,
39                "Secret does not match",
40            ));
41        }
42
43        let addr = request.address.to_socket_addr().await?;
44        let mut target = TcpStream::connect(addr).await?;
45
46        ombrac::io::util::copy_bidirectional(&mut stream, &mut target).await?;
47
48        Ok(())
49    }
50
51    #[cfg(feature = "datagram")]
52    #[inline]
53    async fn handle_udp_associate(conn: impl Unreliable, secret: Secret) -> io::Result<()> {
54        use std::net::SocketAddr;
55        use tokio::net::UdpSocket;
56        use tokio::time::{timeout, Duration};
57
58        const DEFAULT_BUFFER_SIZE: usize = 2 * 1024;
59        const RECV_TIMEOUT: Duration = Duration::from_secs(180);
60
61        let local = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 0));
62        let socket = UdpSocket::bind(local).await?;
63        let sock_send = Arc::new(socket);
64        let sock_recv = Arc::clone(&sock_send);
65        let conn_send = Arc::new(conn);
66        let conn_recv = Arc::clone(&conn_send);
67
68        let mut recv_handle = tokio::spawn(async move {
69            let mut buf = [0u8; DEFAULT_BUFFER_SIZE];
70            loop {
71                let (len, addr) = sock_recv.recv_from(&mut buf).await?;
72                let data = bytes::Bytes::copy_from_slice(&buf[..len]);
73                let packet = Packet::with(secret, addr, data);
74                if let Err(e) = conn_send.send(packet.to_bytes()?).await {
75                    return Err(io::Error::other(e.to_string()));
76                }
77            }
78        });
79
80        let mut send_handle = tokio::spawn(async move {
81            loop {
82                let packet_result = timeout(RECV_TIMEOUT, conn_recv.recv()).await;
83
84                let result = match packet_result {
85                    Ok(value) => value,
86                    Err(_) => return Ok(()), // UDP recv timeout
87                };
88
89                match result {
90                    Ok(mut packet) => {
91                        let packet = Packet::from_bytes(&mut packet)?;
92                        if packet.secret != secret {
93                            return Err(io::Error::new(
94                                io::ErrorKind::PermissionDenied,
95                                "Secret does not match",
96                            ));
97                        };
98                        let target = packet.address.to_socket_addr().await?;
99                        sock_send.send_to(&packet.data, target).await?;
100                    }
101                    Err(e) => {
102                        return Err(io::Error::other(e.to_string()));
103                    }
104                }
105            }
106        });
107
108        let result = tokio::select! {
109            result = &mut recv_handle => {
110                send_handle.abort();
111                result
112            },
113            result = &mut send_handle => {
114                recv_handle.abort();
115                result
116            },
117        };
118
119        match result {
120            Ok(inner_result) => inner_result,
121            Err(e) if e.is_cancelled() => Ok(()),
122            Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)),
123        }
124    }
125
126    pub async fn listen(self) -> io::Result<()> {
127        let secret = self.secret.clone();
128
129        let transport = Arc::new(self.transport);
130
131        #[cfg(feature = "datagram")]
132        let datagram_handle = {
133            let transport = transport.clone();
134            let datagram_handle = tokio::spawn(async move {
135                loop {
136                    match transport.accept_datagram().await {
137                        Ok(stream) => {
138                            tokio::spawn(async move {
139                                if let Err(_error) = Self::handle_unreliable(stream, secret).await {
140                                    error!("{_error}");
141                                }
142                            });
143                        }
144                        Err(_error) => {
145                            error!("{_error}");
146
147                            break;
148                        }
149                    };
150                }
151            });
152
153            datagram_handle
154        };
155
156        loop {
157            match transport.accept_bidirectional().await {
158                Ok(stream) => tokio::spawn(async move {
159                    if let Err(_error) = Self::handle_reliable(stream, secret).await {
160                        error!("{_error}");
161                    }
162                }),
163                Err(err) => {
164                    #[cfg(feature = "datagram")]
165                    datagram_handle.abort();
166
167                    return Err(io::Error::other(err.to_string()));
168                }
169            };
170        }
171    }
172}