Skip to main content

ombrac_server/connection/
mod.rs

1#[cfg(feature = "datagram")]
2mod datagram;
3mod dns;
4mod stream;
5
6use std::future::Future;
7use std::io;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use std::sync::Weak;
11use std::sync::atomic::Ordering;
12use std::time::Duration;
13
14use futures::{SinkExt, StreamExt};
15use tokio::sync::OwnedSemaphorePermit;
16use tokio::sync::{Semaphore, broadcast};
17use tokio::task::JoinHandle;
18use tokio_util::codec::Framed;
19use tokio_util::sync::CancellationToken;
20#[cfg(feature = "tracing")]
21use tracing::Instrument;
22
23use ombrac::codec;
24use ombrac::metrics::Metrics;
25use ombrac::protocol;
26use ombrac_macros::{debug, error, warn};
27use ombrac_transport::{Acceptor, Connection};
28
29use crate::config::ConnectionConfig;
30
31/// Processes a single client connection, handling authentication and tunnel management.
32///
33/// This struct manages the lifecycle of a client connection after it has been
34/// accepted by the server. It performs authentication, sets up tunnel handlers
35/// for streams and datagrams, and manages the connection until it closes.
36pub struct ClientConnectionProcessor<C: Connection> {
37    transport_connection: Arc<C>,
38    shutdown_token: CancellationToken,
39    metrics: Metrics,
40}
41
42impl<C: Connection> ClientConnectionProcessor<C> {
43    /// Handles a new client connection from authentication through tunnel setup.
44    ///
45    /// This method:
46    /// 1. Performs the authentication
47    /// 2. Notifies the authenticator that the connection is accepted
48    /// 3. Sets up and runs tunnel loops for streams and datagrams
49    pub async fn handle<A>(
50        connection: C,
51        authenticator: &A,
52        config: Arc<ConnectionConfig>,
53        metrics: &Metrics,
54    ) -> io::Result<()>
55    where
56        A: Authenticator<C>,
57    {
58        let (auth_context, connection) =
59            Self::perform_authentication(connection, authenticator, &config).await?;
60
61        let transport_connection = Arc::new(connection);
62
63        authenticator
64            .accept(
65                auth_context,
66                ConnectionHandle {
67                    inner: transport_connection.clone(),
68                },
69            )
70            .await;
71
72        let processor = Self {
73            transport_connection,
74            shutdown_token: CancellationToken::new(),
75            metrics: metrics.clone(),
76        };
77
78        processor.run_tunnel_loops().await;
79
80        Ok(())
81    }
82
83    async fn perform_authentication<A: Authenticator<C>>(
84        connection: C,
85        authenticator: &A,
86        config: &ConnectionConfig,
87    ) -> io::Result<(A::AuthContext, C)> {
88        let auth_timeout = Duration::from_secs(config.auth_timeout_secs());
89
90        // Accept control stream
91        let mut control_stream = connection.accept_bidirectional().await.map_err(|e| {
92            io::Error::other(format!("failed to accept bidirectional stream: {}", e))
93        })?;
94        let mut control_frame = Framed::new(&mut control_stream, codec::length_codec());
95
96        // Read and parse hello message
97        let hello = Self::read_hello_message(&mut control_frame, auth_timeout).await?;
98
99        #[cfg(feature = "tracing")]
100        Self::trace_auth(&hello);
101
102        // Verify authentication
103        let auth_context =
104            Self::verify_authentication(&hello, authenticator, auth_timeout, &mut control_frame)
105                .await?;
106
107        Ok((auth_context, connection))
108    }
109
110    /// Reads and parses the hello message from the client.
111    async fn read_hello_message(
112        control_frame: &mut Framed<&mut <C as Connection>::Stream, codec::LengthDelimitedCodec>,
113        timeout: Duration,
114    ) -> io::Result<protocol::ClientHello>
115    where
116        C: Connection,
117    {
118        // Read payload with timeout
119        let payload = tokio::time::timeout(timeout, control_frame.next())
120            .await
121            .map_err(|_| {
122                io::Error::new(
123                    io::ErrorKind::TimedOut,
124                    format!(
125                        "authentication timeout: failed to receive hello message within {:?}",
126                        timeout
127                    ),
128                )
129            })?
130            .ok_or_else(|| {
131                io::Error::new(io::ErrorKind::UnexpectedEof, "stream closed before hello")
132            })??;
133
134        // Decode message
135        let message: codec::ClientMessage = protocol::decode(&payload).map_err(|e| {
136            io::Error::new(
137                io::ErrorKind::InvalidData,
138                format!("failed to decode client message: {}", e),
139            )
140        })?;
141
142        // Extract hello message
143        match message {
144            codec::ClientMessage::Hello(hello) => Ok(hello),
145            _ => {
146                // Invalid message type - disconnect with random delay
147                let stream = control_frame.get_mut();
148                Self::disconnect_with_random_delay(*stream).await;
149                Err(io::Error::new(
150                    io::ErrorKind::InvalidData,
151                    "authentication failed: invalid message type (expected Hello)",
152                ))
153            }
154        }
155    }
156
157    /// Verifies authentication and sends response.
158    async fn verify_authentication<A: Authenticator<C>>(
159        hello: &protocol::ClientHello,
160        authenticator: &A,
161        timeout: Duration,
162        control_frame: &mut Framed<&mut <C as Connection>::Stream, codec::LengthDelimitedCodec>,
163    ) -> io::Result<A::AuthContext>
164    where
165        C: Connection,
166    {
167        // Check protocol version
168        if hello.version != protocol::PROTOCOL_VERSION {
169            Self::handle_auth_failure(control_frame).await;
170            return Err(io::Error::new(
171                io::ErrorKind::PermissionDenied,
172                "incompatible version",
173            ));
174        }
175
176        // Perform authentication with timeout
177        let auth_context = tokio::time::timeout(timeout, authenticator.verify(hello)).await??;
178
179        Self::send_auth_ok_response(control_frame, timeout).await?;
180
181        Ok(auth_context)
182    }
183
184    /// Sends authentication response with timeout.
185    async fn send_auth_ok_response(
186        control_frame: &mut Framed<&mut <C as Connection>::Stream, codec::LengthDelimitedCodec>,
187        timeout: Duration,
188    ) -> io::Result<()>
189    where
190        C: Connection,
191    {
192        tokio::time::timeout(
193            timeout,
194            control_frame.send(protocol::encode(&protocol::ServerAuthResponse::Ok)?),
195        )
196        .await
197        .map_err(|_| {
198            io::Error::new(
199                io::ErrorKind::TimedOut,
200                format!(
201                    "authentication timeout: failed to send response within {:?}",
202                    timeout
203                ),
204            )
205        })??;
206        Ok(())
207    }
208
209    /// Handles authentication failure by disconnecting with random delay.
210    async fn handle_auth_failure(
211        control_frame: &mut Framed<&mut <C as Connection>::Stream, codec::LengthDelimitedCodec>,
212    ) where
213        C: Connection,
214    {
215        // Get the underlying stream for disconnection
216        let stream = control_frame.get_mut();
217        Self::disconnect_with_random_delay(*stream).await;
218    }
219
220    /// Disconnects the stream with a random delay to prevent timing attacks.
221    ///
222    /// This function introduces a random delay (100-500ms) before closing the stream,
223    /// making it harder for attackers to distinguish between different failure modes
224    /// based on response timing.
225    async fn disconnect_with_random_delay(stream: &mut C::Stream) {
226        use rand::RngExt;
227
228        let delay_ms = {
229            let mut rng = rand::rng();
230            rng.random_range(100..=500)
231        };
232
233        tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
234        let _ = tokio::io::AsyncWriteExt::shutdown(stream).await;
235    }
236
237    /// Runs the tunnel loops for streams and datagrams until the connection closes.
238    async fn run_tunnel_loops(&self) {
239        let stream_tunnel_handle = self.spawn_stream_tunnel();
240        #[cfg(feature = "datagram")]
241        let datagram_tunnel_handle = self.spawn_datagram_tunnel();
242
243        #[cfg(not(feature = "datagram"))]
244        let result = stream_tunnel_handle.await;
245
246        #[cfg(feature = "datagram")]
247        let result = tokio::select! {
248            res = stream_tunnel_handle => res,
249            res = datagram_tunnel_handle => res,
250        };
251
252        // Signal all related tasks to shut down
253        self.shutdown_token.cancel();
254
255        match result {
256            Ok(Ok(_)) => debug!("connection closed gracefully"),
257            Ok(Err(e)) => debug!("connection closed with internal error: {}", e),
258            Err(e) => warn!("tunnel handler task panicked or failed: {}", e),
259        }
260    }
261
262    fn spawn_stream_tunnel(&self) -> JoinHandle<io::Result<()>> {
263        use crate::connection::stream::StreamTunnel;
264
265        let connection = Arc::clone(&self.transport_connection);
266        let shutdown = self.shutdown_token.child_token();
267        let tunnel = StreamTunnel::new(connection, shutdown, self.metrics.clone());
268
269        #[cfg(not(feature = "tracing"))]
270        let handle = tokio::spawn(tunnel.accept_loop());
271        #[cfg(feature = "tracing")]
272        let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
273
274        handle
275    }
276
277    #[cfg(feature = "datagram")]
278    fn spawn_datagram_tunnel(&self) -> JoinHandle<io::Result<()>> {
279        use crate::connection::datagram::DatagramTunnel;
280
281        let connection = Arc::clone(&self.transport_connection);
282        let shutdown = self.shutdown_token.child_token();
283        let tunnel = DatagramTunnel::new(connection, shutdown, self.metrics.clone());
284
285        #[cfg(not(feature = "tracing"))]
286        let handle = tokio::spawn(tunnel.accept_loop());
287        #[cfg(feature = "tracing")]
288        let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
289
290        handle
291    }
292
293    #[cfg(feature = "tracing")]
294    fn trace_auth(hello: &protocol::ClientHello) {
295        use std::io::Write;
296
297        let mut buf = [0u8; 6];
298        let mut cursor = std::io::Cursor::new(&mut buf[..]);
299
300        for byte in hello.secret.iter().take(3) {
301            let _ = write!(cursor, "{:02x}", byte);
302        }
303
304        if let Ok(hex_str) = std::str::from_utf8(&buf) {
305            tracing::Span::current().record("secret", hex_str);
306        }
307    }
308}
309
310/// ConnectionAcceptor manages incoming connections with resource limits and lifecycle control.
311///
312/// This struct handles accepting new connections from the transport layer,
313/// limiting concurrent connections, and delegating each connection to a processor.
314///
315/// Generic parameters:
316/// - `T`: The acceptor type that accepts new connections from the transport
317/// - `A`: The authenticator type that handles connection authentication
318pub struct ConnectionAcceptor<T, A> {
319    acceptor: Arc<T>,
320    authenticator: Arc<A>,
321    connection_semaphore: Arc<Semaphore>,
322    config: Arc<ConnectionConfig>,
323    metrics: Metrics,
324}
325
326impl<T: Acceptor, A: Authenticator<T::Connection> + 'static> ConnectionAcceptor<T, A> {
327    /// Creates a new connection acceptor with the given acceptor and authenticator.
328    ///
329    /// The acceptor will use default connection configuration if not provided.
330    pub fn new(acceptor: T, authenticator: A) -> Self {
331        Self::with_config(
332            acceptor,
333            authenticator,
334            Arc::new(ConnectionConfig::default()),
335        )
336    }
337
338    /// Creates a new connection acceptor with custom connection configuration.
339    pub fn with_config(acceptor: T, authenticator: A, config: Arc<ConnectionConfig>) -> Self {
340        let max_connections = config.max_connections();
341        Self {
342            acceptor: Arc::new(acceptor),
343            authenticator: Arc::new(authenticator),
344            connection_semaphore: Arc::new(Semaphore::new(max_connections)),
345            config,
346            metrics: Metrics::new(),
347        }
348    }
349
350    /// Returns a clone-able handle to runtime metrics.
351    ///
352    /// Counters are incremented as connections/streams flow through this acceptor;
353    /// callers can `metrics.snapshot()` at any time to read them.
354    pub fn metrics(&self) -> Metrics {
355        self.metrics.clone()
356    }
357
358    /// Main accept loop that accepts incoming connections and manages them with resource limits.
359    ///
360    /// This method will:
361    /// 1. Accept new connections from the acceptor
362    /// 2. Acquire a semaphore permit to limit total concurrent connections
363    /// 3. Spawn a task to handle each connection
364    /// 4. Gracefully handle shutdown signals
365    pub async fn accept_loop(&self, mut shutdown_rx: broadcast::Receiver<()>) -> io::Result<()> {
366        loop {
367            tokio::select! {
368                _ = shutdown_rx.recv() => {
369                    break;
370                },
371                accepted = self.acceptor.accept() => {
372                    Self::handle_incoming_connection(
373                        accepted,
374                        Arc::clone(&self.authenticator),
375                        Arc::clone(&self.connection_semaphore),
376                        Arc::clone(&self.config),
377                        self.metrics.clone(),
378                    );
379                },
380            }
381        }
382
383        Ok(())
384    }
385
386    /// Handles an incoming connection, either spawning a processor or rejecting it.
387    fn handle_incoming_connection(
388        result: io::Result<<T as Acceptor>::Connection>,
389        authenticator: Arc<A>,
390        semaphore: Arc<Semaphore>,
391        config: Arc<ConnectionConfig>,
392        metrics: Metrics,
393    ) {
394        match result {
395            Ok(connection) => match semaphore.try_acquire_owned() {
396                Ok(permit) => {
397                    metrics
398                        .counters()
399                        .connections_accepted
400                        .fetch_add(1, Ordering::Relaxed);
401                    #[cfg(not(feature = "tracing"))]
402                    tokio::spawn(Self::process_connection_with_permit(
403                        connection,
404                        authenticator,
405                        permit,
406                        config,
407                        metrics,
408                    ));
409                    #[cfg(feature = "tracing")]
410                    tokio::spawn(
411                        Self::process_connection_with_permit(
412                            connection,
413                            authenticator,
414                            permit,
415                            config,
416                            metrics,
417                        )
418                        .in_current_span(),
419                    );
420                }
421                Err(_) => {
422                    metrics
423                        .counters()
424                        .connections_rejected
425                        .fetch_add(1, Ordering::Relaxed);
426                    warn!(
427                        "connection rejected: maximum concurrent connections ({}) reached",
428                        config.max_connections()
429                    );
430                }
431            },
432            Err(err) => {
433                error!("failed to accept connection: {}", err);
434            }
435        }
436    }
437
438    /// Processes a connection with a semaphore permit.
439    ///
440    /// The permit is automatically released when the connection is closed.
441    async fn process_connection_with_permit(
442        connection: <T as Acceptor>::Connection,
443        authenticator: Arc<A>,
444        _permit: OwnedSemaphorePermit,
445        config: Arc<ConnectionConfig>,
446        metrics: Metrics,
447    ) {
448        // Permit is held for the lifetime of this function
449        Self::process_connection(connection, authenticator, config, metrics).await;
450        // Permit is automatically released when dropped
451    }
452
453    #[cfg_attr(feature = "tracing",
454        tracing::instrument(
455            name = "connection",
456            skip_all,
457            fields(
458                id = connection.id(),
459                from = tracing::field::Empty,
460                secret = tracing::field::Empty,
461                reason = tracing::field::Empty
462            )
463        )
464    )]
465    async fn process_connection(
466        connection: <T as Acceptor>::Connection,
467        authenticator: Arc<A>,
468        config: Arc<ConnectionConfig>,
469        metrics: Metrics,
470    ) {
471        #[cfg(feature = "tracing")]
472        if let Ok(addr) = connection.remote_address() {
473            tracing::Span::current().record("from", tracing::field::display(addr));
474        }
475
476        let _result =
477            ClientConnectionProcessor::handle(connection, authenticator.as_ref(), config, &metrics)
478                .await;
479
480        if _result.is_err() {
481            metrics
482                .counters()
483                .connections_auth_failed
484                .fetch_add(1, Ordering::Relaxed);
485        }
486
487        #[cfg(feature = "tracing")]
488        match _result {
489            Ok(_) => {
490                tracing::Span::current().record("reason", "ok");
491                tracing::info!("connection closed");
492            }
493            Err(e) => {
494                tracing::Span::current().record("reason", tracing::field::display(&e));
495                tracing::error!(error = %e, "connection closed with error");
496            }
497        }
498    }
499
500    pub fn local_addr(&self) -> io::Result<SocketAddr> {
501        self.acceptor.local_addr()
502    }
503}
504
/// Handle to an authenticated client connection, passed to `Authenticator::accept`.
///
/// It wraps a strong `Arc` to the connection, so holding it keeps the
/// connection object alive; use `downgrade_inner` to hold it weakly instead.
pub struct ConnectionHandle<C> {
    /// Strong reference to the shared transport connection.
    inner: Arc<C>,
}
508
509impl<C: Connection> ConnectionHandle<C> {
510    pub fn downgrade_inner(&self) -> Weak<C> {
511        Arc::downgrade(&self.inner)
512    }
513
514    pub fn close(&self, error_code: u32, reason: &[u8]) {
515        self.inner.close(error_code, reason);
516    }
517}
518
/// Reasons the server can reject a client during authentication.
///
/// Authenticators return these from `verify`. A `From` impl maps each variant
/// onto an `io::Error` with a matching `io::ErrorKind`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionAuthError {
    /// The client's protocol version is not supported by this server.
    IncompatibleVersion,
    /// The presented authentication secret is wrong.
    InvalidSecret,
    /// The server failed internally while processing authentication.
    ServerError,
    /// Any other failure, described by a free-form message.
    Other(String),
}
531
532impl From<ConnectionAuthError> for io::Error {
533    fn from(value: ConnectionAuthError) -> Self {
534        match value {
535            ConnectionAuthError::IncompatibleVersion => {
536                io::Error::new(io::ErrorKind::Unsupported, "incompatible protocol version")
537            }
538            ConnectionAuthError::InvalidSecret => io::Error::new(
539                io::ErrorKind::PermissionDenied,
540                "invalid authentication secret",
541            ),
542            ConnectionAuthError::ServerError => io::Error::new(
543                io::ErrorKind::ConnectionAborted,
544                "internal server error during auth",
545            ),
546            ConnectionAuthError::Other(msg) => io::Error::other(msg),
547        }
548    }
549}
550
551/// Authenticator trait for verifying and accepting client connections.
552///
553/// This trait provides authentication logic for incoming connections.
554/// Implementations should verify the client's credentials and optionally
555/// perform any setup needed when a connection is accepted.
556pub trait Authenticator<T>: Send + Sync {
557    /// Context type that can be passed from verification to acceptance.
558    type AuthContext: Send;
559
560    /// Verifies the client's hello message and returns an authentication context.
561    ///
562    /// This method is called during the authentication phase to verify the client's
563    /// credentials. If verification succeeds, it returns a context that will
564    /// be passed to `accept`.
565    fn verify(
566        &self,
567        hello: &protocol::ClientHello,
568    ) -> impl Future<Output = Result<Self::AuthContext, ConnectionAuthError>> + Send;
569
570    /// Called after successful authentication to handle the accepted connection.
571    ///
572    /// This method is called with the authentication context from `verify` and
573    /// a handle to the connection. Implementations can use this to perform any
574    /// additional setup or logging.
575    fn accept(
576        &self,
577        auth_context: Self::AuthContext,
578        connection: ConnectionHandle<T>,
579    ) -> impl Future<Output = ()> + Send;
580}
581
582impl<T: Send + Sync> Authenticator<T> for ombrac::protocol::Secret {
583    type AuthContext = ();
584
585    async fn verify(&self, hello: &protocol::ClientHello) -> Result<(), ConnectionAuthError> {
586        if &hello.secret == self {
587            Ok(())
588        } else {
589            Err(ConnectionAuthError::InvalidSecret)
590        }
591    }
592
593    async fn accept(&self, _auth_context: Self::AuthContext, _connection: ConnectionHandle<T>) {}
594}