1use std::{io, sync::Arc};
2
3use ombrac::prelude::*;
4use ombrac_transport::{Acceptor, Reliable};
5
6#[cfg(feature = "datagram")]
7use ombrac_transport::Unreliable;
8
9use ombrac_macros::error;
10
11pub struct Server<T> {
12 secret: Secret,
13 transport: T,
14}
15
16impl<T: Acceptor> Server<T> {
17 pub fn new(secret: Secret, transport: T) -> Self {
18 Self { secret, transport }
19 }
20
21 async fn handle_reliable(stream: impl Reliable, secret: Secret) -> io::Result<()> {
22 Self::handle_tcp_connect(stream, secret).await
23 }
24
25 #[cfg(feature = "datagram")]
26 async fn handle_unreliable(stream: impl Unreliable, secret: Secret) -> io::Result<()> {
27 Self::handle_udp_associate(stream, secret).await
28 }
29
30 #[inline]
31 async fn handle_tcp_connect(mut stream: impl Reliable, secret: Secret) -> io::Result<()> {
32 use tokio::net::TcpStream;
33
34 let request = Connect::from_async_read(&mut stream).await?;
35
36 if request.secret != secret {
37 return Err(io::Error::new(
38 io::ErrorKind::PermissionDenied,
39 "Secret does not match",
40 ));
41 }
42
43 let addr = request.address.to_socket_addr().await?;
44 let mut target = TcpStream::connect(addr).await?;
45
46 ombrac::io::util::copy_bidirectional(&mut stream, &mut target).await?;
47
48 Ok(())
49 }
50
51 #[cfg(feature = "datagram")]
52 #[inline]
53 async fn handle_udp_associate(conn: impl Unreliable, secret: Secret) -> io::Result<()> {
54 use std::net::SocketAddr;
55 use tokio::net::UdpSocket;
56 use tokio::time::{timeout, Duration};
57
58 const DEFAULT_BUFFER_SIZE: usize = 2 * 1024;
59 const RECV_TIMEOUT: Duration = Duration::from_secs(180);
60
61 let local = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 0));
62 let socket = UdpSocket::bind(local).await?;
63 let sock_send = Arc::new(socket);
64 let sock_recv = Arc::clone(&sock_send);
65 let conn_send = Arc::new(conn);
66 let conn_recv = Arc::clone(&conn_send);
67
68 let mut recv_handle = tokio::spawn(async move {
69 let mut buf = [0u8; DEFAULT_BUFFER_SIZE];
70 loop {
71 let (len, addr) = sock_recv.recv_from(&mut buf).await?;
72 let data = bytes::Bytes::copy_from_slice(&buf[..len]);
73 let packet = Packet::with(secret, addr, data);
74 if let Err(e) = conn_send.send(packet.to_bytes()?).await {
75 return Err(io::Error::other(e.to_string()));
76 }
77 }
78 });
79
80 let mut send_handle = tokio::spawn(async move {
81 loop {
82 let packet_result = timeout(RECV_TIMEOUT, conn_recv.recv()).await;
83
84 let result = match packet_result {
85 Ok(value) => value,
86 Err(_) => return Ok(()), };
88
89 match result {
90 Ok(mut packet) => {
91 let packet = Packet::from_bytes(&mut packet)?;
92 if packet.secret != secret {
93 return Err(io::Error::new(
94 io::ErrorKind::PermissionDenied,
95 "Secret does not match",
96 ));
97 };
98 let target = packet.address.to_socket_addr().await?;
99 sock_send.send_to(&packet.data, target).await?;
100 }
101 Err(e) => {
102 return Err(io::Error::other(e.to_string()));
103 }
104 }
105 }
106 });
107
108 let result = tokio::select! {
109 result = &mut recv_handle => {
110 send_handle.abort();
111 result
112 },
113 result = &mut send_handle => {
114 recv_handle.abort();
115 result
116 },
117 };
118
119 match result {
120 Ok(inner_result) => inner_result,
121 Err(e) if e.is_cancelled() => Ok(()),
122 Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)),
123 }
124 }
125
126 pub async fn listen(self) -> io::Result<()> {
127 let secret = self.secret.clone();
128
129 let transport = Arc::new(self.transport);
130
131 #[cfg(feature = "datagram")]
132 let datagram_handle = {
133 let transport = transport.clone();
134 let datagram_handle = tokio::spawn(async move {
135 loop {
136 match transport.accept_datagram().await {
137 Ok(stream) => {
138 tokio::spawn(async move {
139 if let Err(_error) = Self::handle_unreliable(stream, secret).await {
140 error!("{_error}");
141 }
142 });
143 }
144 Err(_error) => {
145 error!("{_error}");
146
147 break;
148 }
149 };
150 }
151 });
152
153 datagram_handle
154 };
155
156 loop {
157 match transport.accept_bidirectional().await {
158 Ok(stream) => tokio::spawn(async move {
159 if let Err(_error) = Self::handle_reliable(stream, secret).await {
160 error!("{_error}");
161 }
162 }),
163 Err(err) => {
164 #[cfg(feature = "datagram")]
165 datagram_handle.abort();
166
167 return Err(io::Error::other(err.to_string()));
168 }
169 };
170 }
171 }
172}