1use std::io;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use tokio::sync::broadcast;
5use tokio::task::JoinHandle;
6
7use futures::{SinkExt, StreamExt};
8use tokio::io::AsyncWriteExt;
9use tokio::net::TcpStream;
10use tokio_util::codec::Framed;
11use tokio_util::sync::CancellationToken;
12
13use ombrac::codec::{LengthDelimitedCodec, ServerHandshakeResponse, UpstreamMessage, length_codec};
14use ombrac::protocol::{self, HandshakeError, PROTOCOLS_VERSION, Secret};
15use ombrac_macros::{debug, error, info, warn};
16use ombrac_transport::{Acceptor, Connection};
17
18#[cfg(feature = "datagram")]
20use self::datagram::DatagraContext;
21
22pub struct Server<T: Acceptor> {
24 acceptor: Arc<T>,
25 secret: Secret,
26}
27
28impl<T: Acceptor> Server<T> {
29 pub fn new(acceptor: T, secret: Secret) -> Self {
31 Self {
32 acceptor: Arc::new(acceptor),
33 secret,
34 }
35 }
36
37 pub async fn accept_loop(&self, mut shutdown_rx: broadcast::Receiver<()>) -> io::Result<()> {
39 loop {
40 tokio::select! {
41 _ = shutdown_rx.recv() => {
42 return Ok(());
43 }
44 accepted = self.acceptor.accept() => {
46 match accepted {
47 Ok(connection) => {
48 let secret = self.secret;
49 let peer_addr = connection.remote_address().unwrap_or_else(|_| "unknown".parse().unwrap());
50 info!("{} Connection established", peer_addr);
51
52 tokio::spawn(async move {
54 if let Err(e) = ConnectionHandler::handle(connection, secret, peer_addr).await {
55 if e.kind() != io::ErrorKind::ConnectionReset && e.kind() != io::ErrorKind::BrokenPipe && e.kind() != io::ErrorKind::UnexpectedEof {
56 error!("{} Connection handler failed: {}", peer_addr, e);
57 } else {
58 info!("{} Connection closed by peer", peer_addr);
59 }
60 }
61 });
62 },
63 Err(_e) => {
64 error!("Failed to accept connection: {}", _e)
65 },
66 }
67 },
68 }
69 }
70 }
71
72 pub fn local_addr(&self) -> io::Result<SocketAddr> {
73 self.acceptor.local_addr()
74 }
75}
76
77struct ConnectionHandler<C: Connection> {
82 connection: Arc<C>,
83 peer_addr: SocketAddr,
84 cancellation_token: CancellationToken,
85}
86
87impl<C: Connection> ConnectionHandler<C> {
88 pub async fn handle(connection: C, secret: Secret, peer_addr: SocketAddr) -> io::Result<()> {
90 let mut control_stream = connection.accept_bidirectional().await?;
92 let mut framed_control = Framed::new(&mut control_stream, length_codec());
93
94 match framed_control.next().await {
95 Some(Ok(payload)) => {
96 let hello_message: UpstreamMessage = protocol::decode(&payload)?;
97 Self::validate_handshake(hello_message, secret, peer_addr, &mut framed_control)
98 .await?;
99 }
100 _ => {
101 return Err(io::Error::new(
102 io::ErrorKind::InvalidData,
103 "Failed to read Hello message",
104 ));
105 }
106 }
107
108 let handler = Self {
110 connection: Arc::new(connection),
111 peer_addr,
112 cancellation_token: CancellationToken::new(),
113 };
114
115 handler.run_proxy_tasks().await;
116 Ok(())
117 }
118
119 async fn validate_handshake(
121 message: UpstreamMessage,
122 secret: Secret,
123 _peer_addr: SocketAddr,
124 framed: &mut Framed<&mut C::Stream, LengthDelimitedCodec>,
125 ) -> io::Result<()> {
126 if let UpstreamMessage::Hello(hello) = message {
127 let response = if hello.version != PROTOCOLS_VERSION {
128 warn!(
129 "{} Handshake failed: Unsupported protocol version",
130 _peer_addr
131 );
132 ServerHandshakeResponse::Err(HandshakeError::UnsupportedVersion)
133 } else if hello.secret != secret {
134 warn!("{} Handshake failed: Invalid secret", _peer_addr);
135 ServerHandshakeResponse::Err(HandshakeError::InvalidSecret)
136 } else {
137 debug!("{} Handshake successful", _peer_addr);
138 ServerHandshakeResponse::Ok
139 };
140
141 let response_bytes = protocol::encode(&response)?;
142 framed.send(response_bytes).await?;
143
144 if matches!(response, ServerHandshakeResponse::Err(_)) {
145 return Err(io::Error::new(
146 io::ErrorKind::PermissionDenied,
147 "Handshake failed",
148 ));
149 }
150 Ok(())
151 } else {
152 Err(io::Error::new(
153 io::ErrorKind::InvalidData,
154 "Expected Hello message",
155 ))
156 }
157 }
158
159 async fn run_proxy_tasks(&self) {
165 let tcp_handler = self.spawn_tcp_handler();
166
167 #[cfg(feature = "datagram")]
168 let udp_handler = self.spawn_udp_handler();
169
170 #[cfg(feature = "datagram")]
172 let result = tokio::select! {
173 res = tcp_handler => res,
174 res = udp_handler => res,
175 };
176
177 #[cfg(not(feature = "datagram"))]
179 let result = tcp_handler.await;
180
181 self.cancellation_token.cancel(); match result {
184 Ok(Ok(_)) => {
185 debug!("{} Client connection closed gracefully.", self.peer_addr);
186 }
187 Ok(Err(e)) => {
188 if e.kind() != io::ErrorKind::ConnectionAborted {
189 warn!(
190 "{} Client connection closed with an error: {}",
191 self.peer_addr, e
192 );
193 }
194 }
195 Err(_join_err) => {
196 warn!(
197 "{} Client connection handler task failed: {}",
198 self.peer_addr, _join_err
199 );
200 }
201 }
202 }
203
204 fn spawn_tcp_handler(&self) -> JoinHandle<io::Result<()>> {
206 let connection = Arc::clone(&self.connection);
207 let peer_addr = self.peer_addr;
208 let token = self.cancellation_token.child_token();
209
210 tokio::spawn(async move {
211 loop {
212 tokio::select! {
213 _ = token.cancelled() => return Ok(()),
214 result = connection.accept_bidirectional() => {
215 let stream = result?;
216
217 tokio::spawn(async move {
219 if let Err(_e) = Self::handle_tcp_stream(stream, peer_addr).await {
220 warn!("{} Stream handler error: {}", peer_addr, _e);
221 }
222 });
223 }
224 }
225 }
226 })
227 }
228
229 async fn handle_tcp_stream(mut stream: C::Stream, _peer_addr: SocketAddr) -> io::Result<()> {
231 let mut framed = Framed::new(&mut stream, length_codec());
232
233 let original_dest = match framed.next().await {
235 Some(Ok(payload)) => match protocol::decode(&payload)? {
236 UpstreamMessage::Connect(connect) => connect.address,
237 _ => {
238 return Err(io::Error::new(
239 io::ErrorKind::InvalidData,
240 "Expected Connect message",
241 ));
242 }
243 },
244 _ => {
245 return Err(io::Error::new(
246 io::ErrorKind::InvalidData,
247 "Failed to read Connect message on new stream",
248 ));
249 }
250 };
251
252 let mut dest_stream = TcpStream::connect(original_dest.to_socket_addr().await?).await?;
253
254 let parts = framed.into_parts();
256 let mut stream = parts.io;
257 if !parts.read_buf.is_empty() {
258 dest_stream.write_all(&parts.read_buf).await?;
259 }
260
261 match ombrac_transport::io::copy_bidirectional(&mut stream, &mut dest_stream).await {
263 Ok(_stats) => {
264 #[cfg(feature = "tracing")]
265 tracing::info!(
266 src_addr = _peer_addr.to_string(),
267 dst_addr = original_dest.to_string(),
268 send = _stats.a_to_b_bytes,
269 recv = _stats.b_to_a_bytes,
270 status = "ok",
271 "Connect"
272 );
273 }
274 Err((err, _stats)) => {
275 #[cfg(feature = "tracing")]
276 tracing::error!(
277 src_addr = _peer_addr.to_string(),
278 dst_addr = original_dest.to_string(),
279 send = _stats.a_to_b_bytes,
280 recv = _stats.b_to_a_bytes,
281 status = "err",
282 error = %err,
283 "Connect"
284 );
285 return Err(err);
286 }
287 }
288
289 Ok(())
290 }
291
292 #[cfg(feature = "datagram")]
294 fn spawn_udp_handler(&self) -> JoinHandle<io::Result<()>> {
295 let context = DatagraContext::new(
296 Arc::clone(&self.connection),
297 self.peer_addr,
298 self.cancellation_token.child_token(),
299 );
300 tokio::spawn(async move { context.run_associate_loop().await })
301 }
302}
303
#[cfg(feature = "datagram")]
mod datagram {
    use std::io;
    use std::net::SocketAddr;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    use bytes::Bytes;
    use moka::future::Cache;
    use ombrac_macros::{debug, info, warn};
    use tokio::net::UdpSocket;
    use tokio::task::AbortHandle;
    use tokio_util::sync::CancellationToken;

    use ombrac::protocol::{Address, UdpPacket};
    use ombrac::reassembly::UdpReassembler;
    use ombrac_transport::Connection;

    /// Per-connection state for relaying UDP traffic between the client and
    /// remote destinations.
    pub(super) struct DatagraContext<C: Connection> {
        /// Client connection that carries datagrams in both directions.
        connection: Arc<C>,
        /// Client address, used only in log messages.
        peer_addr: SocketAddr,
        /// Cancels the associate loop and, through child tokens, every
        /// downstream task.
        token: CancellationToken,
        /// Outbound socket per client session, plus the abort handle of the task
        /// relaying that socket's replies back to the client.
        session_sockets: Cache<u64, (Arc<UdpSocket>, AbortHandle)>,
        /// Resolved addresses keyed by host name only. The port of a forwarded
        /// packet always comes from the packet itself.
        dns_cache: Cache<Bytes, SocketAddr>,
        /// Reassembles fragmented upstream packets.
        reassembler: Arc<UdpReassembler>,
        /// Source of fragment ids for oversized downstream packets.
        fragment_id_counter: Arc<AtomicU32>,
    }

    impl<C: Connection> DatagraContext<C> {
        /// Creates the UDP relay state for one client connection.
        ///
        /// Session sockets are capped at 8192 entries and expire after 65 s
        /// without a cache read. On eviction, the session's downstream task is
        /// aborted. DNS entries expire after 300 s without a read.
        pub(super) fn new(
            connection: Arc<C>,
            peer_addr: SocketAddr,
            token: CancellationToken,
        ) -> Self {
            Self {
                connection,
                peer_addr,
                token,
                session_sockets: Cache::builder()
                    .max_capacity(8192)
                    .time_to_idle(Duration::from_secs(65))
                    .eviction_listener(|_key, val: (Arc<UdpSocket>, AbortHandle), _cause| {
                        val.1.abort();
                        debug!("Session UDP socket evicted due to: {:?}", _cause);
                    })
                    .build(),
                dns_cache: Cache::builder()
                    .time_to_idle(Duration::from_secs(300))
                    .build(),
                reassembler: Arc::new(UdpReassembler::default()),
                fragment_id_counter: Arc::new(AtomicU32::new(0)),
            }
        }

        /// Reads datagrams from the client and forwards them until cancelled.
        ///
        /// `TimedOut` read errors are skipped. Any other read error ends the loop
        /// and is returned. A packet that cannot be delivered is logged and dropped,
        /// so one bad destination cannot tear down the whole client connection.
        pub(super) async fn run_associate_loop(self) -> io::Result<()> {
            loop {
                tokio::select! {
                    _ = self.token.cancelled() => {
                        return Ok(());
                    }
                    result = self.connection.read_datagram() => {
                        let packet_bytes = match result {
                            Ok(bytes) => bytes,
                            Err(e) if e.kind() == io::ErrorKind::TimedOut => {
                                debug!("{} Idle timeout reading datagram from client. Continuing.", self.peer_addr);
                                continue;
                            }
                            Err(e) => {
                                warn!("{} Unrecoverable error reading datagram from client: {}. Closing UDP handler.", self.peer_addr, e);
                                return Err(e);
                            }
                        };
                        if let Err(_e) = self.handle_upstream_packet(packet_bytes).await {
                            warn!("{} Dropping UDP packet from client: {}", self.peer_addr, _e);
                        }
                    }
                }
            }
        }

        /// Decodes one datagram from the client and feeds it to the reassembler.
        /// When the reassembler yields a packet, it is forwarded to its destination.
        ///
        /// Undecodable datagrams are logged and ignored (`Ok(())`).
        ///
        /// # Errors
        ///
        /// Reassembly errors and failures to create the session socket.
        async fn handle_upstream_packet(&self, packet_bytes: Bytes) -> io::Result<()> {
            let packet = match UdpPacket::decode(&packet_bytes) {
                Ok(p) => p,
                Err(_e) => {
                    warn!(
                        "{} Failed to decode UDP packet from client: {}",
                        self.peer_addr, _e
                    );
                    return Ok(());
                }
            };

            if let Some((session_id, address, data)) = self.reassembler.process(packet).await? {
                let socket = self
                    .get_or_create_session_socket(session_id, &address)
                    .await?;
                self.forward_to_destination(session_id, socket, address, data);
            }
            Ok(())
        }

        /// Sends `data` to `address` from the session socket in a spawned task.
        ///
        /// Domain names are resolved through `dns_cache`, falling back to a DNS
        /// lookup. Resolution and send failures are logged and the packet is dropped.
        fn forward_to_destination(
            &self,
            session_id: u64,
            socket: Arc<UdpSocket>,
            address: Address,
            data: Bytes,
        ) {
            let peer_addr = self.peer_addr;
            let dns_cache = self.dns_cache.clone();

            tokio::spawn(async move {
                let dest_addr = match address {
                    Address::SocketV4(addr) => SocketAddr::V4(addr),
                    Address::SocketV6(addr) => SocketAddr::V6(addr),
                    Address::Domain(ref domain, port) => {
                        if let Some(cached) = dns_cache.get(domain).await {
                            // The cache is keyed by host name only, so the port must
                            // come from this packet, not from the lookup that filled it.
                            SocketAddr::new(cached.ip(), port)
                        } else {
                            let domain_str = String::from_utf8_lossy(domain);
                            match tokio::net::lookup_host(format!("{}:{}", domain_str, port)).await
                            {
                                Ok(mut addrs) => {
                                    if let Some(addr) = addrs.next() {
                                        dns_cache.insert(domain.clone(), addr).await;
                                        addr
                                    } else {
                                        warn!(
                                            "{} [Session][{}] DNS resolution failed for {}",
                                            peer_addr, session_id, domain_str
                                        );
                                        return;
                                    }
                                }
                                Err(_e) => {
                                    warn!(
                                        "{} [Session][{}] DNS resolution error for {}: {}",
                                        peer_addr, session_id, domain_str, _e
                                    );
                                    return;
                                }
                            }
                        }
                    }
                };

                if let Err(_e) = socket.send_to(&data, dest_addr).await {
                    warn!(
                        "{} [Session] Failed to send packet to {}: {}",
                        peer_addr, dest_addr, _e
                    );
                }
            });
        }

        /// Returns the outbound socket for `session_id`, creating it on first use.
        ///
        /// A new socket is bound to an ephemeral port on the unspecified address
        /// of the destination's address family. A downstream task is spawned to
        /// relay its replies, and the socket and task are stored in
        /// `session_sockets`.
        ///
        /// # Errors
        ///
        /// DNS lookup errors (`NotFound` if a domain resolves to no address) and
        /// socket bind errors.
        async fn get_or_create_session_socket(
            &self,
            session_id: u64,
            dest_addr: &Address,
        ) -> io::Result<Arc<UdpSocket>> {
            if let Some((socket, _)) = self.session_sockets.get(&session_id).await {
                return Ok(socket);
            }

            let bind_addr = match dest_addr {
                Address::SocketV4(_) => "0.0.0.0:0",
                Address::SocketV6(_) => "[::]:0",
                Address::Domain(domain_bytes, port) => {
                    let resolved = match self.dns_cache.get(domain_bytes).await {
                        Some(addr) => addr,
                        None => {
                            let domain =
                                format!("{}:{}", String::from_utf8_lossy(domain_bytes), port);
                            let Some(addr) = tokio::net::lookup_host(&domain).await?.next() else {
                                return Err(io::Error::new(
                                    io::ErrorKind::NotFound,
                                    format!("Domain name {domain} could not be resolved"),
                                ));
                            };
                            // Cache the result. Forwarding then reuses this lookup, and the
                            // destination's address family matches the bound socket's.
                            self.dns_cache.insert(domain_bytes.clone(), addr).await;
                            addr
                        }
                    };
                    if resolved.is_ipv4() {
                        "0.0.0.0:0"
                    } else {
                        "[::]:0"
                    }
                }
            };

            let new_socket = Arc::new(UdpSocket::bind(bind_addr).await?);

            info!(
                "{} [Session][{}] New session for {}, listening on {}",
                self.peer_addr,
                session_id,
                dest_addr,
                new_socket.local_addr()?
            );

            let abort_handle = self.spawn_downstream_task(session_id, Arc::clone(&new_socket));
            self.session_sockets
                .insert(session_id, (Arc::clone(&new_socket), abort_handle))
                .await;

            Ok(new_socket)
        }

        /// Spawns the task relaying replies from `socket` back to the client and
        /// returns its abort handle.
        fn spawn_downstream_task(&self, session_id: u64, socket: Arc<UdpSocket>) -> AbortHandle {
            let conn = Arc::clone(&self.connection);
            let token = self.token.child_token();
            let frag_counter = Arc::clone(&self.fragment_id_counter);
            let peer_addr = self.peer_addr;

            let handle = tokio::spawn(async move {
                Self::run_downstream_task(conn, peer_addr, session_id, socket, frag_counter, token)
                    .await;
            });

            handle.abort_handle()
        }

        /// Relays datagrams received on a session socket back to the client.
        ///
        /// Stops when the token is cancelled, the socket errors, or the client
        /// connection rejects a datagram. Replies larger than the per-datagram
        /// payload budget are split into fragments sharing one id from
        /// `fragment_id_counter`.
        async fn run_downstream_task(
            connection: Arc<C>,
            peer_addr: SocketAddr,
            session_id: u64,
            socket: Arc<UdpSocket>,
            fragment_id_counter: Arc<AtomicU32>,
            token: CancellationToken,
        ) {
            // 1350 is used when the connection reports no datagram size limit.
            let max_datagram_size = connection.max_datagram_size().unwrap_or(1350);
            let overhead = UdpPacket::fragmented_overhead();
            let max_payload_size = max_datagram_size.saturating_sub(overhead).max(1);
            let mut buf = vec![0u8; 65535];

            'recv: loop {
                tokio::select! {
                    _ = token.cancelled() => break,
                    result = socket.recv_from(&mut buf) => {
                        let (len, from_addr) = match result {
                            Ok(r) => r,
                            Err(_e) => {
                                warn!("{} [Session][{}] Error receiving from remote socket: {}", peer_addr, session_id, _e);
                                break;
                            }
                        };

                        let address = Address::from(from_addr);
                        let data = Bytes::copy_from_slice(&buf[..len]);

                        if data.len() <= max_payload_size {
                            let packet = UdpPacket::Unfragmented { session_id, address, data };
                            if let Ok(encoded) = packet.encode()
                                && connection.send_datagram(encoded).await.is_err()
                            {
                                break;
                            }
                        } else {
                            let fragment_id = fragment_id_counter.fetch_add(1, Ordering::Relaxed);
                            let fragments = UdpPacket::split_packet(
                                session_id,
                                address,
                                data,
                                max_payload_size,
                                fragment_id,
                            );
                            for fragment in fragments {
                                // A rejected send ends the task, exactly like the
                                // unfragmented branch above.
                                if let Ok(encoded) = fragment.encode()
                                    && connection.send_datagram(encoded).await.is_err()
                                {
                                    break 'recv;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}