ombrac_server/
server.rs

1use std::io;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use tokio::sync::broadcast;
5use tokio::task::JoinHandle;
6
7use futures::{SinkExt, StreamExt};
8use tokio::io::AsyncWriteExt;
9use tokio::net::TcpStream;
10use tokio_util::codec::Framed;
11use tokio_util::sync::CancellationToken;
12
13use ombrac::codec::{LengthDelimitedCodec, ServerHandshakeResponse, UpstreamMessage, length_codec};
14use ombrac::protocol::{self, HandshakeError, PROTOCOLS_VERSION, Secret};
15use ombrac_macros::{debug, error, info, warn};
16use ombrac_transport::{Acceptor, Connection};
17
18// Conditionally compiled datagram module. All UDP logic is encapsulated here.
19#[cfg(feature = "datagram")]
20use self::datagram::DatagraContext;
21
22/// The main server struct, responsible for accepting incoming connections and spawning handlers.
23pub struct Server<T: Acceptor> {
24    acceptor: Arc<T>,
25    secret: Secret,
26}
27
28impl<T: Acceptor> Server<T> {
29    /// Creates a new server instance.
30    pub fn new(acceptor: T, secret: Secret) -> Self {
31        Self {
32            acceptor: Arc::new(acceptor),
33            secret,
34        }
35    }
36
37    /// Runs the main server loop, accepting new connections.
38    pub async fn accept_loop(&self, mut shutdown_rx: broadcast::Receiver<()>) -> io::Result<()> {
39        loop {
40            tokio::select! {
41                _ = shutdown_rx.recv() => {
42                    return Ok(());
43                }
44                // Accept a new connection.
45                accepted = self.acceptor.accept() => {
46                    match accepted {
47                        Ok(connection) => {
48                            let secret = self.secret;
49                            let peer_addr = connection.remote_address().unwrap_or_else(|_| "unknown".parse().unwrap());
50                            info!("{} Connection established", peer_addr);
51
52                            // Spawn a new task to handle the entire lifecycle of the client connection.
53                            tokio::spawn(async move {
54                                if let Err(e) = ConnectionHandler::handle(connection, secret, peer_addr).await {
55                                    if e.kind() != io::ErrorKind::ConnectionReset && e.kind() != io::ErrorKind::BrokenPipe && e.kind() != io::ErrorKind::UnexpectedEof {
56                                        error!("{} Connection handler failed: {}", peer_addr, e);
57                                    } else {
58                                        info!("{} Connection closed by peer", peer_addr);
59                                    }
60                                }
61                            });
62                        },
63                        Err(_e) => {
64                            error!("Failed to accept connection: {}", _e)
65                        },
66                    }
67                },
68            }
69        }
70    }
71
72    pub fn local_addr(&self) -> io::Result<SocketAddr> {
73        self.acceptor.local_addr()
74    }
75}
76
77/// Manages the lifecycle of a single client connection.
78///
79/// This struct is responsible for performing the initial handshake and then spawning
80/// and managing the tasks for handling TCP and UDP traffic for the duration of the connection.
81struct ConnectionHandler<C: Connection> {
82    connection: Arc<C>,
83    peer_addr: SocketAddr,
84    cancellation_token: CancellationToken,
85}
86
87impl<C: Connection> ConnectionHandler<C> {
88    /// The main entry point for handling a new client connection.
89    pub async fn handle(connection: C, secret: Secret, peer_addr: SocketAddr) -> io::Result<()> {
90        // Step 1: Perform the handshake to authenticate the client.
91        let mut control_stream = connection.accept_bidirectional().await?;
92        let mut framed_control = Framed::new(&mut control_stream, length_codec());
93
94        match framed_control.next().await {
95            Some(Ok(payload)) => {
96                let hello_message: UpstreamMessage = protocol::decode(&payload)?;
97                Self::validate_handshake(hello_message, secret, peer_addr, &mut framed_control)
98                    .await?;
99            }
100            _ => {
101                return Err(io::Error::new(
102                    io::ErrorKind::InvalidData,
103                    "Failed to read Hello message",
104                ));
105            }
106        }
107
108        // Step 2: Set up the handler and run the proxy tasks.
109        let handler = Self {
110            connection: Arc::new(connection),
111            peer_addr,
112            cancellation_token: CancellationToken::new(),
113        };
114
115        handler.run_proxy_tasks().await;
116        Ok(())
117    }
118
119    /// Validates the client's Hello message and sends the appropriate response.
120    async fn validate_handshake(
121        message: UpstreamMessage,
122        secret: Secret,
123        _peer_addr: SocketAddr,
124        framed: &mut Framed<&mut C::Stream, LengthDelimitedCodec>,
125    ) -> io::Result<()> {
126        if let UpstreamMessage::Hello(hello) = message {
127            let response = if hello.version != PROTOCOLS_VERSION {
128                warn!(
129                    "{} Handshake failed: Unsupported protocol version",
130                    _peer_addr
131                );
132                ServerHandshakeResponse::Err(HandshakeError::UnsupportedVersion)
133            } else if hello.secret != secret {
134                warn!("{} Handshake failed: Invalid secret", _peer_addr);
135                ServerHandshakeResponse::Err(HandshakeError::InvalidSecret)
136            } else {
137                debug!("{} Handshake successful", _peer_addr);
138                ServerHandshakeResponse::Ok
139            };
140
141            let response_bytes = protocol::encode(&response)?;
142            framed.send(response_bytes).await?;
143
144            if matches!(response, ServerHandshakeResponse::Err(_)) {
145                return Err(io::Error::new(
146                    io::ErrorKind::PermissionDenied,
147                    "Handshake failed",
148                ));
149            }
150            Ok(())
151        } else {
152            Err(io::Error::new(
153                io::ErrorKind::InvalidData,
154                "Expected Hello message",
155            ))
156        }
157    }
158
159    /// Spawns and manages the long-running tasks for TCP and UDP proxying.
160    ///
161    /// This function waits for either the TCP or UDP handler to exit (due to an
162    /// error or graceful shutdown) and then triggers a cancellation for all other
163    /// tasks associated with this connection.
164    async fn run_proxy_tasks(&self) {
165        let tcp_handler = self.spawn_tcp_handler();
166
167        #[cfg(feature = "datagram")]
168        let udp_handler = self.spawn_udp_handler();
169
170        // Wait for either handler to complete or fail.
171        #[cfg(feature = "datagram")]
172        let result = tokio::select! {
173            res = tcp_handler => res,
174            res = udp_handler => res,
175        };
176
177        // If datagram feature is disabled, only await the TCP handler.
178        #[cfg(not(feature = "datagram"))]
179        let result = tcp_handler.await;
180
181        self.cancellation_token.cancel(); // Signal all related tasks to shut down.
182
183        match result {
184            Ok(Ok(_)) => {
185                debug!("{} Client connection closed gracefully.", self.peer_addr);
186            }
187            Ok(Err(e)) => {
188                if e.kind() != io::ErrorKind::ConnectionAborted {
189                    warn!(
190                        "{} Client connection closed with an error: {}",
191                        self.peer_addr, e
192                    );
193                }
194            }
195            Err(_join_err) => {
196                warn!(
197                    "{} Client connection handler task failed: {}",
198                    self.peer_addr, _join_err
199                );
200            }
201        }
202    }
203
204    /// Spawns the task responsible for accepting and handling new TCP streams.
205    fn spawn_tcp_handler(&self) -> JoinHandle<io::Result<()>> {
206        let connection = Arc::clone(&self.connection);
207        let peer_addr = self.peer_addr;
208        let token = self.cancellation_token.child_token();
209
210        tokio::spawn(async move {
211            loop {
212                tokio::select! {
213                    _ = token.cancelled() => return Ok(()),
214                    result = connection.accept_bidirectional() => {
215                        let stream = result?;
216
217                        // Spawn a separate task for each TCP stream to avoid blocking the acceptor.
218                        tokio::spawn(async move {
219                            if let Err(_e) = Self::handle_tcp_stream(stream, peer_addr).await {
220                                warn!("{} Stream handler error: {}", peer_addr, _e);
221                            }
222                        });
223                    }
224                }
225            }
226        })
227    }
228
229    /// Handles a single TCP stream, proxying data to the requested destination.
230    async fn handle_tcp_stream(mut stream: C::Stream, _peer_addr: SocketAddr) -> io::Result<()> {
231        let mut framed = Framed::new(&mut stream, length_codec());
232
233        // Read the destination address from the client.
234        let original_dest = match framed.next().await {
235            Some(Ok(payload)) => match protocol::decode(&payload)? {
236                UpstreamMessage::Connect(connect) => connect.address,
237                _ => {
238                    return Err(io::Error::new(
239                        io::ErrorKind::InvalidData,
240                        "Expected Connect message",
241                    ));
242                }
243            },
244            _ => {
245                return Err(io::Error::new(
246                    io::ErrorKind::InvalidData,
247                    "Failed to read Connect message on new stream",
248                ));
249            }
250        };
251
252        let mut dest_stream = TcpStream::connect(original_dest.to_socket_addr().await?).await?;
253
254        // Forward any data that was already buffered in the framing codec.
255        let parts = framed.into_parts();
256        let mut stream = parts.io;
257        if !parts.read_buf.is_empty() {
258            dest_stream.write_all(&parts.read_buf).await?;
259        }
260
261        // Copy data in both directions until one side closes.
262        match ombrac_transport::io::copy_bidirectional(&mut stream, &mut dest_stream).await {
263            Ok(_stats) => {
264                #[cfg(feature = "tracing")]
265                tracing::info!(
266                    src_addr = _peer_addr.to_string(),
267                    dst_addr = original_dest.to_string(),
268                    send = _stats.a_to_b_bytes,
269                    recv = _stats.b_to_a_bytes,
270                    status = "ok",
271                    "Connect"
272                );
273            }
274            Err((err, _stats)) => {
275                #[cfg(feature = "tracing")]
276                tracing::error!(
277                    src_addr = _peer_addr.to_string(),
278                    dst_addr = original_dest.to_string(),
279                    send = _stats.a_to_b_bytes,
280                    recv = _stats.b_to_a_bytes,
281                    status = "err",
282                    error = %err,
283                    "Connect"
284                );
285                return Err(err);
286            }
287        }
288
289        Ok(())
290    }
291
292    /// Spawns the task responsible for handling all UDP traffic.
293    #[cfg(feature = "datagram")]
294    fn spawn_udp_handler(&self) -> JoinHandle<io::Result<()>> {
295        let context = DatagraContext::new(
296            Arc::clone(&self.connection),
297            self.peer_addr,
298            self.cancellation_token.child_token(),
299        );
300        tokio::spawn(async move { context.run_associate_loop().await })
301    }
302}
303
/// UDP proxying for a client connection, carried over the transport's datagrams.
#[cfg(feature = "datagram")]
mod datagram {
    use std::io;
    use std::net::SocketAddr;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    use bytes::Bytes;
    use moka::future::Cache;
    use ombrac_macros::{debug, info, warn};
    use tokio::net::UdpSocket;
    use tokio::task::AbortHandle;
    use tokio_util::sync::CancellationToken;

    use ombrac::protocol::{Address, UdpPacket};
    use ombrac::reassembly::UdpReassembler;
    use ombrac_transport::Connection;

    /// Maximum number of UDP sessions tracked per client connection.
    const MAX_SESSIONS: u64 = 8192;
    /// A session idle for this long is evicted and its downstream task aborted.
    const SESSION_IDLE_TIMEOUT: Duration = Duration::from_secs(65);
    /// Upper bound on cached domain resolutions. The keys are chosen by the client, so the
    /// cache must not be allowed to grow without limit.
    const MAX_DNS_CACHE_ENTRIES: u64 = 4096;
    /// Cached resolutions unused for this long are dropped.
    const DNS_CACHE_IDLE_TIMEOUT: Duration = Duration::from_secs(300);
    /// Datagram size assumed when the transport does not report one.
    const DEFAULT_MAX_DATAGRAM_SIZE: usize = 1350;
    /// Receive buffer size: the largest length a 16-bit UDP length field can describe.
    const RECV_BUFFER_SIZE: usize = 65535;

    /// Contains the shared state for a connection's UDP proxy.
    pub(super) struct DatagraContext<C: Connection> {
        /// Transport connection used to exchange datagrams with the client.
        connection: Arc<C>,
        /// Client address, used in log output.
        peer_addr: SocketAddr,
        /// Stops the associate loop when cancelled; downstream tasks get child tokens.
        token: CancellationToken,
        /// Outbound socket per session id, plus the abort handle of its downstream task.
        session_sockets: Cache<u64, (Arc<UdpSocket>, AbortHandle)>,
        /// Resolved addresses keyed by the raw domain bytes sent by the client.
        dns_cache: Cache<Bytes, SocketAddr>,
        /// Reassembles fragmented client packets into whole datagrams.
        reassembler: Arc<UdpReassembler>,
        /// Source of fragment ids for oversized datagrams sent back to the client.
        fragment_id_counter: Arc<AtomicU32>,
    }

    impl<C: Connection> DatagraContext<C> {
        /// Creates the UDP proxy state for one client connection.
        pub(super) fn new(
            connection: Arc<C>,
            peer_addr: SocketAddr,
            token: CancellationToken,
        ) -> Self {
            Self {
                connection,
                peer_addr,
                token,
                session_sockets: Cache::builder()
                    .max_capacity(MAX_SESSIONS)
                    .time_to_idle(SESSION_IDLE_TIMEOUT)
                    .eviction_listener(|_key, val: (Arc<UdpSocket>, AbortHandle), _cause| {
                        // Stop the session's downstream task together with its socket.
                        val.1.abort();
                        debug!("Session UDP socket evicted due to: {:?}", _cause);
                    })
                    .build(),
                dns_cache: Cache::builder()
                    .max_capacity(MAX_DNS_CACHE_ENTRIES)
                    .time_to_idle(DNS_CACHE_IDLE_TIMEOUT)
                    .build(),
                reassembler: Arc::new(UdpReassembler::default()),
                fragment_id_counter: Arc::new(AtomicU32::new(0)),
            }
        }

        /// Runs the main UDP proxy loop, handling upstream data from the client.
        ///
        /// Returns `Ok(())` when the token is cancelled. A `TimedOut` read is logged and
        /// retried; any other read error is returned, as is any error from
        /// [`Self::handle_upstream_packet`].
        pub(super) async fn run_associate_loop(self) -> io::Result<()> {
            loop {
                tokio::select! {
                    _ = self.token.cancelled() => {
                        return Ok(());
                    }
                    // Read a raw datagram from the client connection.
                    result = self.connection.read_datagram() => {
                        let packet_bytes = match result {
                            Ok(bytes) => bytes,
                            Err(e) => {
                                if e.kind() == io::ErrorKind::TimedOut {
                                    debug!("{} Idle timeout reading datagram from client. Continuing.", self.peer_addr);
                                    continue;
                                }

                                warn!("{} Unrecoverable error reading datagram from client: {}. Closing UDP handler.", self.peer_addr, e);
                                return Err(e);
                            }
                        };
                        self.handle_upstream_packet(packet_bytes).await?;
                    }
                }
            }
        }

        /// Processes a single raw packet received from the client.
        ///
        /// The following are logged and the packet is dropped, so one bad packet does not
        /// tear down the whole client connection:
        /// * packets that fail to decode;
        /// * sessions whose socket cannot be set up (e.g. an unresolvable domain).
        ///
        /// Errors from the reassembler are still returned to the caller.
        async fn handle_upstream_packet(&self, packet_bytes: Bytes) -> io::Result<()> {
            let packet = match UdpPacket::decode(&packet_bytes) {
                Ok(p) => p,
                Err(e) => {
                    warn!(
                        "{} Failed to decode UDP packet from client: {}",
                        self.peer_addr, e
                    );
                    return Ok(()); // Skip malformed packet
                }
            };

            // The reassembler yields a full datagram once all of its fragments have arrived.
            let Some((session_id, address, data)) = self.reassembler.process(packet).await? else {
                return Ok(());
            };

            let socket = match self.get_or_create_session_socket(session_id, &address).await {
                Ok(socket) => socket,
                Err(_e) => {
                    warn!(
                        "{} [Session][{}] Failed to set up UDP session for {}: {}",
                        self.peer_addr, session_id, address, _e
                    );
                    return Ok(()); // Drop this datagram; keep the connection alive.
                }
            };
            self.forward_to_destination(session_id, socket, address, data);
            Ok(())
        }

        /// Forwards a reassembled datagram to its final destination.
        ///
        /// Runs on a spawned task so the associate loop is not blocked. Domain addresses
        /// are resolved through `dns_cache` first. Resolution and send failures are only
        /// logged. Because each datagram is sent from its own task, this code does not
        /// guarantee send order between datagrams.
        fn forward_to_destination(
            &self,
            session_id: u64,
            socket: Arc<UdpSocket>,
            address: Address,
            data: Bytes,
        ) {
            let peer_addr = self.peer_addr;
            let dns_cache = self.dns_cache.clone();

            tokio::spawn(async move {
                let dest_addr = match address {
                    Address::SocketV4(addr) => SocketAddr::V4(addr),
                    Address::SocketV6(addr) => SocketAddr::V6(addr),
                    Address::Domain(ref domain, port) => {
                        if let Some(addr) = dns_cache.get(domain).await {
                            addr
                        } else {
                            let domain_str = String::from_utf8_lossy(domain);
                            match tokio::net::lookup_host(format!("{}:{}", domain_str, port)).await
                            {
                                Ok(mut addrs) => {
                                    if let Some(addr) = addrs.next() {
                                        dns_cache.insert(domain.clone(), addr).await;
                                        addr
                                    } else {
                                        warn!(
                                            "{} [Session][{}] DNS resolution failed for {}",
                                            peer_addr, session_id, domain_str
                                        );
                                        return;
                                    }
                                }
                                Err(e) => {
                                    warn!(
                                        "{} [Session][{}] DNS resolution error for {}: {}",
                                        peer_addr, session_id, domain_str, e
                                    );
                                    return;
                                }
                            }
                        }
                    }
                };

                if let Err(e) = socket.send_to(&data, dest_addr).await {
                    warn!(
                        "{} [Session] Failed to send packet to {}: {}",
                        peer_addr, dest_addr, e
                    );
                }
            });
        }

        /// Retrieves an existing UDP socket for a session or creates a new one.
        ///
        /// A new socket is bound to an ephemeral port in the same address family as the
        /// destination. For domain destinations the name is resolved (via `dns_cache` when
        /// possible), and a fresh answer is cached. `forward_to_destination` then sends to
        /// that same address instead of resolving it again.
        ///
        /// # Errors
        ///
        /// * Resolution errors, or `NotFound` if a domain resolves to no address.
        /// * Errors from binding the socket or reading its local address.
        async fn get_or_create_session_socket(
            &self,
            session_id: u64,
            dest_addr: &Address,
        ) -> io::Result<Arc<UdpSocket>> {
            if let Some((socket, _)) = self.session_sockets.get(&session_id).await {
                return Ok(socket);
            }

            // Bind in the destination's address family so `send_to` can reach it.
            let bind_addr = match dest_addr {
                Address::SocketV4(_) => "0.0.0.0:0",
                Address::SocketV6(_) => "[::]:0",
                Address::Domain(domain_bytes, port) => {
                    let resolved = match self.dns_cache.get(domain_bytes).await {
                        Some(addr) => addr,
                        None => {
                            let domain =
                                format!("{}:{}", String::from_utf8_lossy(domain_bytes), port);
                            let Some(addr) = tokio::net::lookup_host(&domain).await?.next()
                            else {
                                return Err(io::Error::new(
                                    io::ErrorKind::NotFound,
                                    format!("Domain name {domain} could not be resolved"),
                                ));
                            };
                            self.dns_cache.insert(domain_bytes.clone(), addr).await;
                            addr
                        }
                    };
                    if resolved.is_ipv4() {
                        "0.0.0.0:0"
                    } else {
                        "[::]:0"
                    }
                }
            };

            // Create a new UDP socket bound to a random port.
            let new_socket = Arc::new(UdpSocket::bind(bind_addr).await?);

            info!(
                "{} [Session][{}] New session for {}, listening on {}",
                self.peer_addr,
                session_id,
                dest_addr,
                new_socket.local_addr()?
            );

            // Spawn a task to handle downstream traffic (from destination back to client).
            let abort_handle = self.spawn_downstream_task(session_id, Arc::clone(&new_socket));
            self.session_sockets
                .insert(session_id, (Arc::clone(&new_socket), abort_handle))
                .await;

            Ok(new_socket)
        }

        /// Spawns a dedicated task for handling downstream traffic for a single UDP session.
        ///
        /// The returned handle is stored next to the session socket, so the cache's
        /// eviction listener can abort the task.
        fn spawn_downstream_task(&self, session_id: u64, socket: Arc<UdpSocket>) -> AbortHandle {
            let conn = Arc::clone(&self.connection);
            let token = self.token.child_token();
            let frag_counter = Arc::clone(&self.fragment_id_counter);
            let peer_addr = self.peer_addr;

            let handle = tokio::spawn(async move {
                Self::run_downstream_task(conn, peer_addr, session_id, socket, frag_counter, token)
                    .await;
            });

            handle.abort_handle()
        }

        /// Task that reads from a remote UDP socket and forwards data back to the client.
        ///
        /// Replies that do not fit in one datagram are split into fragments. The task ends
        /// in any of these cases:
        /// * the token is cancelled;
        /// * receiving from the remote socket fails;
        /// * sending a datagram to the client fails.
        async fn run_downstream_task(
            connection: Arc<C>,
            peer_addr: SocketAddr,
            session_id: u64,
            socket: Arc<UdpSocket>,
            fragment_id_counter: Arc<AtomicU32>,
            token: CancellationToken,
        ) {
            let max_datagram_size = connection
                .max_datagram_size()
                .unwrap_or(DEFAULT_MAX_DATAGRAM_SIZE);
            let overhead = UdpPacket::fragmented_overhead();
            // Clamp to at least 1 so the split size is never zero.
            let max_payload_size = max_datagram_size.saturating_sub(overhead).max(1);
            let mut buf = vec![0u8; RECV_BUFFER_SIZE];

            'recv: loop {
                tokio::select! {
                    _ = token.cancelled() => break 'recv,
                    result = socket.recv_from(&mut buf) => {
                        let (len, from_addr) = match result {
                            Ok(r) => r,
                            Err(e) => {
                                warn!("{} [Session][{}] Error receiving from remote socket: {}", peer_addr, session_id, e);
                                break 'recv;
                            }
                        };

                        let address = Address::from(from_addr);
                        let data = Bytes::copy_from_slice(&buf[..len]);

                        // This packet might need to be fragmented before sending back to client.
                        if data.len() <= max_payload_size {
                            let packet = UdpPacket::Unfragmented { session_id, address, data };
                            if let Ok(encoded) = packet.encode()
                                && connection.send_datagram(encoded).await.is_err()
                            {
                                break 'recv;
                            }
                        } else {
                            let fragment_id = fragment_id_counter.fetch_add(1, Ordering::Relaxed);
                            let fragments = UdpPacket::split_packet(session_id, address, data, max_payload_size, fragment_id);
                            for fragment in fragments {
                                // A failed send ends the whole task, matching the
                                // unfragmented path, not just this fragment loop.
                                if let Ok(encoded) = fragment.encode()
                                    && connection.send_datagram(encoded).await.is_err()
                                {
                                    break 'recv;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}