Skip to main content

ombrac_server/connection/
mod.rs

1#[cfg(feature = "datagram")]
2mod datagram;
3mod dns;
4mod stream;
5
6use std::future::Future;
7use std::io;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use std::sync::Weak;
11use std::time::Duration;
12
13use futures::{SinkExt, StreamExt};
14use tokio::sync::OwnedSemaphorePermit;
15use tokio::sync::{Semaphore, broadcast};
16use tokio::task::JoinHandle;
17use tokio_util::codec::Framed;
18use tokio_util::sync::CancellationToken;
19#[cfg(feature = "tracing")]
20use tracing::Instrument;
21
22use ombrac::codec;
23use ombrac::protocol;
24use ombrac_macros::{debug, error, warn};
25use ombrac_transport::{Acceptor, Connection};
26
27use crate::config::ConnectionConfig;
28
29/// Processes a single client connection, handling authentication and tunnel management.
30///
31/// This struct manages the lifecycle of a client connection after it has been
32/// accepted by the server. It performs authentication, sets up tunnel handlers
33/// for streams and datagrams, and manages the connection until it closes.
34pub struct ClientConnectionProcessor<C: Connection> {
35    transport_connection: Arc<C>,
36    shutdown_token: CancellationToken,
37}
38
39impl<C: Connection> ClientConnectionProcessor<C> {
40    /// Handles a new client connection from authentication through tunnel setup.
41    ///
42    /// This method:
43    /// 1. Performs the authentication
44    /// 2. Notifies the authenticator that the connection is accepted
45    /// 3. Sets up and runs tunnel loops for streams and datagrams
46    pub async fn handle<A>(
47        connection: C,
48        authenticator: &A,
49        config: Arc<ConnectionConfig>,
50    ) -> io::Result<()>
51    where
52        A: Authenticator<C>,
53    {
54        let (auth_context, connection) =
55            Self::perform_authentication(connection, authenticator, &config).await?;
56
57        let transport_connection = Arc::new(connection);
58
59        authenticator
60            .accept(
61                auth_context,
62                ConnectionHandle {
63                    inner: transport_connection.clone(),
64                },
65            )
66            .await;
67
68        let processor = Self {
69            transport_connection,
70            shutdown_token: CancellationToken::new(),
71        };
72
73        processor.run_tunnel_loops().await;
74
75        Ok(())
76    }
77
78    async fn perform_authentication<A: Authenticator<C>>(
79        connection: C,
80        authenticator: &A,
81        config: &ConnectionConfig,
82    ) -> io::Result<(A::AuthContext, C)> {
83        let auth_timeout = Duration::from_secs(config.auth_timeout_secs());
84
85        // Accept control stream
86        let mut control_stream = connection.accept_bidirectional().await.map_err(|e| {
87            io::Error::other(format!("failed to accept bidirectional stream: {}", e))
88        })?;
89        let mut control_frame = Framed::new(&mut control_stream, codec::length_codec());
90
91        // Read and parse hello message
92        let hello = Self::read_hello_message(&mut control_frame, auth_timeout).await?;
93
94        #[cfg(feature = "tracing")]
95        Self::trace_auth(&hello);
96
97        // Verify authentication
98        let auth_context =
99            Self::verify_authentication(&hello, authenticator, auth_timeout, &mut control_frame)
100                .await?;
101
102        Ok((auth_context, connection))
103    }
104
105    /// Reads and parses the hello message from the client.
106    async fn read_hello_message(
107        control_frame: &mut Framed<&mut <C as Connection>::Stream, codec::LengthDelimitedCodec>,
108        timeout: Duration,
109    ) -> io::Result<protocol::ClientHello>
110    where
111        C: Connection,
112    {
113        // Read payload with timeout
114        let payload = tokio::time::timeout(timeout, control_frame.next())
115            .await
116            .map_err(|_| {
117                io::Error::new(
118                    io::ErrorKind::TimedOut,
119                    format!(
120                        "authentication timeout: failed to receive hello message within {:?}",
121                        timeout
122                    ),
123                )
124            })?
125            .ok_or_else(|| {
126                io::Error::new(io::ErrorKind::UnexpectedEof, "stream closed before hello")
127            })??;
128
129        // Decode message
130        let message: codec::ClientMessage = protocol::decode(&payload).map_err(|e| {
131            io::Error::new(
132                io::ErrorKind::InvalidData,
133                format!("failed to decode client message: {}", e),
134            )
135        })?;
136
137        // Extract hello message
138        match message {
139            codec::ClientMessage::Hello(hello) => Ok(hello),
140            _ => {
141                // Invalid message type - disconnect with random delay
142                let stream = control_frame.get_mut();
143                Self::disconnect_with_random_delay(*stream).await;
144                Err(io::Error::new(
145                    io::ErrorKind::InvalidData,
146                    "authentication failed: invalid message type (expected Hello)",
147                ))
148            }
149        }
150    }
151
152    /// Verifies authentication and sends response.
153    async fn verify_authentication<A: Authenticator<C>>(
154        hello: &protocol::ClientHello,
155        authenticator: &A,
156        timeout: Duration,
157        control_frame: &mut Framed<&mut <C as Connection>::Stream, codec::LengthDelimitedCodec>,
158    ) -> io::Result<A::AuthContext>
159    where
160        C: Connection,
161    {
162        // Check protocol version
163        if hello.version != protocol::PROTOCOL_VERSION {
164            Self::handle_auth_failure(control_frame).await;
165            return Err(io::Error::new(
166                io::ErrorKind::PermissionDenied,
167                "incompatible version",
168            ));
169        }
170
171        // Perform authentication with timeout
172        let auth_context = tokio::time::timeout(timeout, authenticator.verify(hello)).await??;
173
174        Self::send_auth_ok_response(control_frame, timeout).await?;
175
176        Ok(auth_context)
177    }
178
179    /// Sends authentication response with timeout.
180    async fn send_auth_ok_response(
181        control_frame: &mut Framed<&mut <C as Connection>::Stream, codec::LengthDelimitedCodec>,
182        timeout: Duration,
183    ) -> io::Result<()>
184    where
185        C: Connection,
186    {
187        tokio::time::timeout(
188            timeout,
189            control_frame.send(protocol::encode(&protocol::ServerAuthResponse::Ok)?),
190        )
191        .await
192        .map_err(|_| {
193            io::Error::new(
194                io::ErrorKind::TimedOut,
195                format!(
196                    "authentication timeout: failed to send response within {:?}",
197                    timeout
198                ),
199            )
200        })??;
201        Ok(())
202    }
203
204    /// Handles authentication failure by disconnecting with random delay.
205    async fn handle_auth_failure(
206        control_frame: &mut Framed<&mut <C as Connection>::Stream, codec::LengthDelimitedCodec>,
207    ) where
208        C: Connection,
209    {
210        // Get the underlying stream for disconnection
211        let stream = control_frame.get_mut();
212        Self::disconnect_with_random_delay(*stream).await;
213    }
214
215    /// Disconnects the stream with a random delay to prevent timing attacks.
216    ///
217    /// This function introduces a random delay (100-500ms) before closing the stream,
218    /// making it harder for attackers to distinguish between different failure modes
219    /// based on response timing.
220    async fn disconnect_with_random_delay(stream: &mut C::Stream) {
221        use rand::Rng;
222
223        let delay_ms = {
224            let mut rng = rand::rng();
225            rng.random_range(100..=500)
226        };
227
228        tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
229        let _ = tokio::io::AsyncWriteExt::shutdown(stream).await;
230    }
231
232    /// Runs the tunnel loops for streams and datagrams until the connection closes.
233    async fn run_tunnel_loops(&self) {
234        let stream_tunnel_handle = self.spawn_stream_tunnel();
235        #[cfg(feature = "datagram")]
236        let datagram_tunnel_handle = self.spawn_datagram_tunnel();
237
238        #[cfg(not(feature = "datagram"))]
239        let result = stream_tunnel_handle.await;
240
241        #[cfg(feature = "datagram")]
242        let result = tokio::select! {
243            res = stream_tunnel_handle => res,
244            res = datagram_tunnel_handle => res,
245        };
246
247        // Signal all related tasks to shut down
248        self.shutdown_token.cancel();
249
250        match result {
251            Ok(Ok(_)) => debug!("connection closed gracefully"),
252            Ok(Err(e)) => debug!("connection closed with internal error: {}", e),
253            Err(e) => warn!("tunnel handler task panicked or failed: {}", e),
254        }
255    }
256
257    fn spawn_stream_tunnel(&self) -> JoinHandle<io::Result<()>> {
258        use crate::connection::stream::StreamTunnel;
259
260        let connection = Arc::clone(&self.transport_connection);
261        let shutdown = self.shutdown_token.child_token();
262        let tunnel = StreamTunnel::new(connection, shutdown);
263
264        #[cfg(not(feature = "tracing"))]
265        let handle = tokio::spawn(tunnel.accept_loop());
266        #[cfg(feature = "tracing")]
267        let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
268
269        handle
270    }
271
272    #[cfg(feature = "datagram")]
273    fn spawn_datagram_tunnel(&self) -> JoinHandle<io::Result<()>> {
274        use crate::connection::datagram::DatagramTunnel;
275
276        let connection = Arc::clone(&self.transport_connection);
277        let shutdown = self.shutdown_token.child_token();
278        let tunnel = DatagramTunnel::new(connection, shutdown);
279
280        #[cfg(not(feature = "tracing"))]
281        let handle = tokio::spawn(tunnel.accept_loop());
282        #[cfg(feature = "tracing")]
283        let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
284
285        handle
286    }
287
288    #[cfg(feature = "tracing")]
289    fn trace_auth(hello: &protocol::ClientHello) {
290        use std::io::Write;
291
292        let mut buf = [0u8; 6];
293        let mut cursor = std::io::Cursor::new(&mut buf[..]);
294
295        for byte in hello.secret.iter().take(3) {
296            let _ = write!(cursor, "{:02x}", byte);
297        }
298
299        if let Ok(hex_str) = std::str::from_utf8(&buf) {
300            tracing::Span::current().record("secret", hex_str);
301        }
302    }
303}
304
305/// ConnectionAcceptor manages incoming connections with resource limits and lifecycle control.
306///
307/// This struct handles accepting new connections from the transport layer,
308/// limiting concurrent connections, and delegating each connection to a processor.
309///
310/// Generic parameters:
311/// - `T`: The acceptor type that accepts new connections from the transport
312/// - `A`: The authenticator type that handles connection authentication
313pub struct ConnectionAcceptor<T, A> {
314    acceptor: Arc<T>,
315    authenticator: Arc<A>,
316    connection_semaphore: Arc<Semaphore>,
317    config: Arc<ConnectionConfig>,
318}
319
320impl<T: Acceptor, A: Authenticator<T::Connection> + 'static> ConnectionAcceptor<T, A> {
321    /// Creates a new connection acceptor with the given acceptor and authenticator.
322    ///
323    /// The acceptor will use default connection configuration if not provided.
324    pub fn new(acceptor: T, authenticator: A) -> Self {
325        Self::with_config(
326            acceptor,
327            authenticator,
328            Arc::new(ConnectionConfig::default()),
329        )
330    }
331
332    /// Creates a new connection acceptor with custom connection configuration.
333    pub fn with_config(acceptor: T, authenticator: A, config: Arc<ConnectionConfig>) -> Self {
334        let max_connections = config.max_connections();
335        Self {
336            acceptor: Arc::new(acceptor),
337            authenticator: Arc::new(authenticator),
338            connection_semaphore: Arc::new(Semaphore::new(max_connections)),
339            config,
340        }
341    }
342
343    /// Main accept loop that accepts incoming connections and manages them with resource limits.
344    ///
345    /// This method will:
346    /// 1. Accept new connections from the acceptor
347    /// 2. Acquire a semaphore permit to limit total concurrent connections
348    /// 3. Spawn a task to handle each connection
349    /// 4. Gracefully handle shutdown signals
350    pub async fn accept_loop(&self, mut shutdown_rx: broadcast::Receiver<()>) -> io::Result<()> {
351        loop {
352            tokio::select! {
353                _ = shutdown_rx.recv() => {
354                    break;
355                },
356                accepted = self.acceptor.accept() => {
357                    Self::handle_incoming_connection(
358                        accepted,
359                        Arc::clone(&self.authenticator),
360                        Arc::clone(&self.connection_semaphore),
361                        Arc::clone(&self.config),
362                    );
363                },
364            }
365        }
366
367        Ok(())
368    }
369
370    /// Handles an incoming connection, either spawning a processor or rejecting it.
371    fn handle_incoming_connection(
372        result: io::Result<<T as Acceptor>::Connection>,
373        authenticator: Arc<A>,
374        semaphore: Arc<Semaphore>,
375        config: Arc<ConnectionConfig>,
376    ) {
377        match result {
378            Ok(connection) => match semaphore.try_acquire_owned() {
379                Ok(permit) => {
380                    #[cfg(not(feature = "tracing"))]
381                    tokio::spawn(Self::process_connection_with_permit(
382                        connection,
383                        authenticator,
384                        permit,
385                        config,
386                    ));
387                    #[cfg(feature = "tracing")]
388                    tokio::spawn(
389                        Self::process_connection_with_permit(
390                            connection,
391                            authenticator,
392                            permit,
393                            config,
394                        )
395                        .in_current_span(),
396                    );
397                }
398                Err(_) => {
399                    warn!(
400                        "connection rejected: maximum concurrent connections ({}) reached",
401                        config.max_connections()
402                    );
403                }
404            },
405            Err(err) => {
406                error!("failed to accept connection: {}", err);
407            }
408        }
409    }
410
411    /// Processes a connection with a semaphore permit.
412    ///
413    /// The permit is automatically released when the connection is closed.
414    async fn process_connection_with_permit(
415        connection: <T as Acceptor>::Connection,
416        authenticator: Arc<A>,
417        _permit: OwnedSemaphorePermit,
418        config: Arc<ConnectionConfig>,
419    ) {
420        // Permit is held for the lifetime of this function
421        Self::process_connection(connection, authenticator, config).await;
422        // Permit is automatically released when dropped
423    }
424
425    #[cfg_attr(feature = "tracing",
426        tracing::instrument(
427            name = "connection",
428            skip_all,
429            fields(
430                id = connection.id(),
431                from = tracing::field::Empty,
432                secret = tracing::field::Empty,
433                reason = tracing::field::Empty
434            )
435        )
436    )]
437    async fn process_connection(
438        connection: <T as Acceptor>::Connection,
439        authenticator: Arc<A>,
440        config: Arc<ConnectionConfig>,
441    ) {
442        #[cfg(feature = "tracing")]
443        if let Ok(addr) = connection.remote_address() {
444            tracing::Span::current().record("from", tracing::field::display(addr));
445        }
446
447        let _result =
448            ClientConnectionProcessor::handle(connection, authenticator.as_ref(), config).await;
449
450        #[cfg(feature = "tracing")]
451        match _result {
452            Ok(_) => {
453                tracing::Span::current().record("reason", "ok");
454                tracing::info!("connection closed");
455            }
456            Err(e) => {
457                tracing::Span::current().record("reason", tracing::field::display(&e));
458                tracing::error!(error = %e, "connection closed with error");
459            }
460        }
461    }
462
463    pub fn local_addr(&self) -> io::Result<SocketAddr> {
464        self.acceptor.local_addr()
465    }
466}
467
/// Handle to an authenticated connection, given to `Authenticator::accept`.
///
/// It holds an `Arc` clone of the transport connection, so keeping the handle
/// keeps the connection object alive; `downgrade_inner` yields a non-owning
/// reference instead.
pub struct ConnectionHandle<C> {
    inner: Arc<C>,
}
471
472impl<C: Connection> ConnectionHandle<C> {
473    pub fn downgrade_inner(&self) -> Weak<C> {
474        Arc::downgrade(&self.inner)
475    }
476
477    pub fn close(&self, error_code: u32, reason: &[u8]) {
478        self.inner.close(error_code, reason);
479    }
480}
481
/// Authentication error types returned by the server during authentication.
///
/// Implements `Display` and `std::error::Error`, so values can be logged
/// directly, boxed as `dyn Error`, or converted to `io::Error`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionAuthError {
    /// Client protocol version is not supported by the server.
    IncompatibleVersion,
    /// Authentication secret is invalid or incorrect.
    InvalidSecret,
    /// Internal server error during authentication processing.
    ServerError,
    /// Any other failure, described by the contained message.
    Other(String),
}

impl std::fmt::Display for ConnectionAuthError {
    /// Uses the same messages as the `io::Error` conversion.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::IncompatibleVersion => f.write_str("incompatible protocol version"),
            Self::InvalidSecret => f.write_str("invalid authentication secret"),
            Self::ServerError => f.write_str("internal server error during auth"),
            Self::Other(msg) => f.write_str(msg),
        }
    }
}

impl std::error::Error for ConnectionAuthError {}
494
495impl From<ConnectionAuthError> for io::Error {
496    fn from(value: ConnectionAuthError) -> Self {
497        match value {
498            ConnectionAuthError::IncompatibleVersion => {
499                io::Error::new(io::ErrorKind::Unsupported, "incompatible protocol version")
500            }
501            ConnectionAuthError::InvalidSecret => io::Error::new(
502                io::ErrorKind::PermissionDenied,
503                "invalid authentication secret",
504            ),
505            ConnectionAuthError::ServerError => io::Error::new(
506                io::ErrorKind::ConnectionAborted,
507                "internal server error during auth",
508            ),
509            ConnectionAuthError::Other(msg) => io::Error::other(msg),
510        }
511    }
512}
513
514/// Authenticator trait for verifying and accepting client connections.
515///
516/// This trait provides authentication logic for incoming connections.
517/// Implementations should verify the client's credentials and optionally
518/// perform any setup needed when a connection is accepted.
519pub trait Authenticator<T>: Send + Sync {
520    /// Context type that can be passed from verification to acceptance.
521    type AuthContext: Send;
522
523    /// Verifies the client's hello message and returns an authentication context.
524    ///
525    /// This method is called during the authentication phase to verify the client's
526    /// credentials. If verification succeeds, it returns a context that will
527    /// be passed to `accept`.
528    fn verify(
529        &self,
530        hello: &protocol::ClientHello,
531    ) -> impl Future<Output = Result<Self::AuthContext, ConnectionAuthError>> + Send;
532
533    /// Called after successful authentication to handle the accepted connection.
534    ///
535    /// This method is called with the authentication context from `verify` and
536    /// a handle to the connection. Implementations can use this to perform any
537    /// additional setup or logging.
538    fn accept(
539        &self,
540        auth_context: Self::AuthContext,
541        connection: ConnectionHandle<T>,
542    ) -> impl Future<Output = ()> + Send;
543}
544
545impl<T: Send + Sync> Authenticator<T> for ombrac::protocol::Secret {
546    type AuthContext = ();
547
548    async fn verify(&self, hello: &protocol::ClientHello) -> Result<(), ConnectionAuthError> {
549        if &hello.secret == self {
550            Ok(())
551        } else {
552            Err(ConnectionAuthError::InvalidSecret)
553        }
554    }
555
556    async fn accept(&self, _auth_context: Self::AuthContext, _connection: ConnectionHandle<T>) {}
557}