1#[cfg(feature = "datagram")]
2mod datagram;
3mod dns;
4mod stream;
5
6use std::future::Future;
7use std::io;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use std::sync::Weak;
11use std::sync::atomic::Ordering;
12use std::time::Duration;
13
14use futures::{SinkExt, StreamExt};
15use tokio::sync::OwnedSemaphorePermit;
16use tokio::sync::{Semaphore, broadcast};
17use tokio::task::JoinHandle;
18use tokio_util::codec::Framed;
19use tokio_util::sync::CancellationToken;
20#[cfg(feature = "tracing")]
21use tracing::Instrument;
22
23use ombrac::codec;
24use ombrac::metrics::Metrics;
25use ombrac::protocol;
26use ombrac_macros::{debug, error, warn};
27use ombrac_transport::{Acceptor, Connection};
28
29use crate::config::ConnectionConfig;
30
31pub struct ClientConnectionProcessor<C: Connection> {
37 transport_connection: Arc<C>,
38 shutdown_token: CancellationToken,
39 metrics: Metrics,
40}
41
42impl<C: Connection> ClientConnectionProcessor<C> {
43 pub async fn handle<A>(
50 connection: C,
51 authenticator: &A,
52 config: Arc<ConnectionConfig>,
53 metrics: &Metrics,
54 ) -> io::Result<()>
55 where
56 A: Authenticator<C>,
57 {
58 let (auth_context, connection) =
59 Self::perform_authentication(connection, authenticator, &config).await?;
60
61 let transport_connection = Arc::new(connection);
62
63 authenticator
64 .accept(
65 auth_context,
66 ConnectionHandle {
67 inner: transport_connection.clone(),
68 },
69 )
70 .await;
71
72 let processor = Self {
73 transport_connection,
74 shutdown_token: CancellationToken::new(),
75 metrics: metrics.clone(),
76 };
77
78 processor.run_tunnel_loops().await;
79
80 Ok(())
81 }
82
83 async fn perform_authentication<A: Authenticator<C>>(
84 connection: C,
85 authenticator: &A,
86 config: &ConnectionConfig,
87 ) -> io::Result<(A::AuthContext, C)> {
88 let auth_timeout = Duration::from_secs(config.auth_timeout_secs());
89
90 let mut control_stream = connection.accept_bidirectional().await.map_err(|e| {
92 io::Error::other(format!("failed to accept bidirectional stream: {}", e))
93 })?;
94 let mut control_frame = Framed::new(&mut control_stream, codec::length_codec());
95
96 let hello = Self::read_hello_message(&mut control_frame, auth_timeout).await?;
98
99 #[cfg(feature = "tracing")]
100 Self::trace_auth(&hello);
101
102 let auth_context =
104 Self::verify_authentication(&hello, authenticator, auth_timeout, &mut control_frame)
105 .await?;
106
107 Ok((auth_context, connection))
108 }
109
110 async fn read_hello_message(
112 control_frame: &mut Framed<&mut <C as Connection>::Stream, codec::LengthDelimitedCodec>,
113 timeout: Duration,
114 ) -> io::Result<protocol::ClientHello>
115 where
116 C: Connection,
117 {
118 let payload = tokio::time::timeout(timeout, control_frame.next())
120 .await
121 .map_err(|_| {
122 io::Error::new(
123 io::ErrorKind::TimedOut,
124 format!(
125 "authentication timeout: failed to receive hello message within {:?}",
126 timeout
127 ),
128 )
129 })?
130 .ok_or_else(|| {
131 io::Error::new(io::ErrorKind::UnexpectedEof, "stream closed before hello")
132 })??;
133
134 let message: codec::ClientMessage = protocol::decode(&payload).map_err(|e| {
136 io::Error::new(
137 io::ErrorKind::InvalidData,
138 format!("failed to decode client message: {}", e),
139 )
140 })?;
141
142 match message {
144 codec::ClientMessage::Hello(hello) => Ok(hello),
145 _ => {
146 let stream = control_frame.get_mut();
148 Self::disconnect_with_random_delay(*stream).await;
149 Err(io::Error::new(
150 io::ErrorKind::InvalidData,
151 "authentication failed: invalid message type (expected Hello)",
152 ))
153 }
154 }
155 }
156
157 async fn verify_authentication<A: Authenticator<C>>(
159 hello: &protocol::ClientHello,
160 authenticator: &A,
161 timeout: Duration,
162 control_frame: &mut Framed<&mut <C as Connection>::Stream, codec::LengthDelimitedCodec>,
163 ) -> io::Result<A::AuthContext>
164 where
165 C: Connection,
166 {
167 if hello.version != protocol::PROTOCOL_VERSION {
169 Self::handle_auth_failure(control_frame).await;
170 return Err(io::Error::new(
171 io::ErrorKind::PermissionDenied,
172 "incompatible version",
173 ));
174 }
175
176 let auth_context = tokio::time::timeout(timeout, authenticator.verify(hello)).await??;
178
179 Self::send_auth_ok_response(control_frame, timeout).await?;
180
181 Ok(auth_context)
182 }
183
184 async fn send_auth_ok_response(
186 control_frame: &mut Framed<&mut <C as Connection>::Stream, codec::LengthDelimitedCodec>,
187 timeout: Duration,
188 ) -> io::Result<()>
189 where
190 C: Connection,
191 {
192 tokio::time::timeout(
193 timeout,
194 control_frame.send(protocol::encode(&protocol::ServerAuthResponse::Ok)?),
195 )
196 .await
197 .map_err(|_| {
198 io::Error::new(
199 io::ErrorKind::TimedOut,
200 format!(
201 "authentication timeout: failed to send response within {:?}",
202 timeout
203 ),
204 )
205 })??;
206 Ok(())
207 }
208
209 async fn handle_auth_failure(
211 control_frame: &mut Framed<&mut <C as Connection>::Stream, codec::LengthDelimitedCodec>,
212 ) where
213 C: Connection,
214 {
215 let stream = control_frame.get_mut();
217 Self::disconnect_with_random_delay(*stream).await;
218 }
219
220 async fn disconnect_with_random_delay(stream: &mut C::Stream) {
226 use rand::RngExt;
227
228 let delay_ms = {
229 let mut rng = rand::rng();
230 rng.random_range(100..=500)
231 };
232
233 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
234 let _ = tokio::io::AsyncWriteExt::shutdown(stream).await;
235 }
236
237 async fn run_tunnel_loops(&self) {
239 let stream_tunnel_handle = self.spawn_stream_tunnel();
240 #[cfg(feature = "datagram")]
241 let datagram_tunnel_handle = self.spawn_datagram_tunnel();
242
243 #[cfg(not(feature = "datagram"))]
244 let result = stream_tunnel_handle.await;
245
246 #[cfg(feature = "datagram")]
247 let result = tokio::select! {
248 res = stream_tunnel_handle => res,
249 res = datagram_tunnel_handle => res,
250 };
251
252 self.shutdown_token.cancel();
254
255 match result {
256 Ok(Ok(_)) => debug!("connection closed gracefully"),
257 Ok(Err(e)) => debug!("connection closed with internal error: {}", e),
258 Err(e) => warn!("tunnel handler task panicked or failed: {}", e),
259 }
260 }
261
262 fn spawn_stream_tunnel(&self) -> JoinHandle<io::Result<()>> {
263 use crate::connection::stream::StreamTunnel;
264
265 let connection = Arc::clone(&self.transport_connection);
266 let shutdown = self.shutdown_token.child_token();
267 let tunnel = StreamTunnel::new(connection, shutdown, self.metrics.clone());
268
269 #[cfg(not(feature = "tracing"))]
270 let handle = tokio::spawn(tunnel.accept_loop());
271 #[cfg(feature = "tracing")]
272 let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
273
274 handle
275 }
276
277 #[cfg(feature = "datagram")]
278 fn spawn_datagram_tunnel(&self) -> JoinHandle<io::Result<()>> {
279 use crate::connection::datagram::DatagramTunnel;
280
281 let connection = Arc::clone(&self.transport_connection);
282 let shutdown = self.shutdown_token.child_token();
283 let tunnel = DatagramTunnel::new(connection, shutdown, self.metrics.clone());
284
285 #[cfg(not(feature = "tracing"))]
286 let handle = tokio::spawn(tunnel.accept_loop());
287 #[cfg(feature = "tracing")]
288 let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
289
290 handle
291 }
292
293 #[cfg(feature = "tracing")]
294 fn trace_auth(hello: &protocol::ClientHello) {
295 use std::io::Write;
296
297 let mut buf = [0u8; 6];
298 let mut cursor = std::io::Cursor::new(&mut buf[..]);
299
300 for byte in hello.secret.iter().take(3) {
301 let _ = write!(cursor, "{:02x}", byte);
302 }
303
304 if let Ok(hex_str) = std::str::from_utf8(&buf) {
305 tracing::Span::current().record("secret", hex_str);
306 }
307 }
308}
309
310pub struct ConnectionAcceptor<T, A> {
319 acceptor: Arc<T>,
320 authenticator: Arc<A>,
321 connection_semaphore: Arc<Semaphore>,
322 config: Arc<ConnectionConfig>,
323 metrics: Metrics,
324}
325
326impl<T: Acceptor, A: Authenticator<T::Connection> + 'static> ConnectionAcceptor<T, A> {
327 pub fn new(acceptor: T, authenticator: A) -> Self {
331 Self::with_config(
332 acceptor,
333 authenticator,
334 Arc::new(ConnectionConfig::default()),
335 )
336 }
337
338 pub fn with_config(acceptor: T, authenticator: A, config: Arc<ConnectionConfig>) -> Self {
340 let max_connections = config.max_connections();
341 Self {
342 acceptor: Arc::new(acceptor),
343 authenticator: Arc::new(authenticator),
344 connection_semaphore: Arc::new(Semaphore::new(max_connections)),
345 config,
346 metrics: Metrics::new(),
347 }
348 }
349
350 pub fn metrics(&self) -> Metrics {
355 self.metrics.clone()
356 }
357
358 pub async fn accept_loop(&self, mut shutdown_rx: broadcast::Receiver<()>) -> io::Result<()> {
366 loop {
367 tokio::select! {
368 _ = shutdown_rx.recv() => {
369 break;
370 },
371 accepted = self.acceptor.accept() => {
372 Self::handle_incoming_connection(
373 accepted,
374 Arc::clone(&self.authenticator),
375 Arc::clone(&self.connection_semaphore),
376 Arc::clone(&self.config),
377 self.metrics.clone(),
378 );
379 },
380 }
381 }
382
383 Ok(())
384 }
385
386 fn handle_incoming_connection(
388 result: io::Result<<T as Acceptor>::Connection>,
389 authenticator: Arc<A>,
390 semaphore: Arc<Semaphore>,
391 config: Arc<ConnectionConfig>,
392 metrics: Metrics,
393 ) {
394 match result {
395 Ok(connection) => match semaphore.try_acquire_owned() {
396 Ok(permit) => {
397 metrics
398 .counters()
399 .connections_accepted
400 .fetch_add(1, Ordering::Relaxed);
401 #[cfg(not(feature = "tracing"))]
402 tokio::spawn(Self::process_connection_with_permit(
403 connection,
404 authenticator,
405 permit,
406 config,
407 metrics,
408 ));
409 #[cfg(feature = "tracing")]
410 tokio::spawn(
411 Self::process_connection_with_permit(
412 connection,
413 authenticator,
414 permit,
415 config,
416 metrics,
417 )
418 .in_current_span(),
419 );
420 }
421 Err(_) => {
422 metrics
423 .counters()
424 .connections_rejected
425 .fetch_add(1, Ordering::Relaxed);
426 warn!(
427 "connection rejected: maximum concurrent connections ({}) reached",
428 config.max_connections()
429 );
430 }
431 },
432 Err(err) => {
433 error!("failed to accept connection: {}", err);
434 }
435 }
436 }
437
438 async fn process_connection_with_permit(
442 connection: <T as Acceptor>::Connection,
443 authenticator: Arc<A>,
444 _permit: OwnedSemaphorePermit,
445 config: Arc<ConnectionConfig>,
446 metrics: Metrics,
447 ) {
448 Self::process_connection(connection, authenticator, config, metrics).await;
450 }
452
453 #[cfg_attr(feature = "tracing",
454 tracing::instrument(
455 name = "connection",
456 skip_all,
457 fields(
458 id = connection.id(),
459 from = tracing::field::Empty,
460 secret = tracing::field::Empty,
461 reason = tracing::field::Empty
462 )
463 )
464 )]
465 async fn process_connection(
466 connection: <T as Acceptor>::Connection,
467 authenticator: Arc<A>,
468 config: Arc<ConnectionConfig>,
469 metrics: Metrics,
470 ) {
471 #[cfg(feature = "tracing")]
472 if let Ok(addr) = connection.remote_address() {
473 tracing::Span::current().record("from", tracing::field::display(addr));
474 }
475
476 let _result =
477 ClientConnectionProcessor::handle(connection, authenticator.as_ref(), config, &metrics)
478 .await;
479
480 if _result.is_err() {
481 metrics
482 .counters()
483 .connections_auth_failed
484 .fetch_add(1, Ordering::Relaxed);
485 }
486
487 #[cfg(feature = "tracing")]
488 match _result {
489 Ok(_) => {
490 tracing::Span::current().record("reason", "ok");
491 tracing::info!("connection closed");
492 }
493 Err(e) => {
494 tracing::Span::current().record("reason", tracing::field::display(&e));
495 tracing::error!(error = %e, "connection closed with error");
496 }
497 }
498 }
499
500 pub fn local_addr(&self) -> io::Result<SocketAddr> {
501 self.acceptor.local_addr()
502 }
503}
504
/// Handle to an authenticated connection, given to [`Authenticator::accept`].
///
/// It shares ownership (via `Arc`) of the same transport connection that the
/// connection's tunnels use.
pub struct ConnectionHandle<C> {
    /// Shared transport connection.
    inner: Arc<C>,
}
508
509impl<C: Connection> ConnectionHandle<C> {
510 pub fn downgrade_inner(&self) -> Weak<C> {
511 Arc::downgrade(&self.inner)
512 }
513
514 pub fn close(&self, error_code: u32, reason: &[u8]) {
515 self.inner.close(error_code, reason);
516 }
517}
518
/// Reasons an [`Authenticator`] can reject a client during the handshake.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionAuthError {
    /// The client's protocol version is not supported by this server.
    IncompatibleVersion,
    /// The client's secret was not accepted.
    InvalidSecret,
    /// Authentication could not be completed because of a server-side fault.
    ServerError,
    /// Any other rejection, described by the contained message.
    Other(String),
}

impl std::fmt::Display for ConnectionAuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ConnectionAuthError::IncompatibleVersion => f.write_str("incompatible protocol version"),
            ConnectionAuthError::InvalidSecret => f.write_str("invalid authentication secret"),
            ConnectionAuthError::ServerError => f.write_str("internal server error during auth"),
            ConnectionAuthError::Other(msg) => f.write_str(msg),
        }
    }
}

impl std::error::Error for ConnectionAuthError {}

/// Maps each rejection reason to the closest `io::ErrorKind`.
///
/// The `io::Error` wraps the original `ConnectionAuthError`, so its message is
/// the variant's `Display` text. Callers can recover the variant with
/// `io::Error::get_ref` plus `downcast_ref`.
impl From<ConnectionAuthError> for io::Error {
    fn from(value: ConnectionAuthError) -> Self {
        let kind = match &value {
            ConnectionAuthError::IncompatibleVersion => io::ErrorKind::Unsupported,
            ConnectionAuthError::InvalidSecret => io::ErrorKind::PermissionDenied,
            ConnectionAuthError::ServerError => io::ErrorKind::ConnectionAborted,
            ConnectionAuthError::Other(_) => io::ErrorKind::Other,
        };
        io::Error::new(kind, value)
    }
}
550
551pub trait Authenticator<T>: Send + Sync {
557 type AuthContext: Send;
559
560 fn verify(
566 &self,
567 hello: &protocol::ClientHello,
568 ) -> impl Future<Output = Result<Self::AuthContext, ConnectionAuthError>> + Send;
569
570 fn accept(
576 &self,
577 auth_context: Self::AuthContext,
578 connection: ConnectionHandle<T>,
579 ) -> impl Future<Output = ()> + Send;
580}
581
582impl<T: Send + Sync> Authenticator<T> for ombrac::protocol::Secret {
583 type AuthContext = ();
584
585 async fn verify(&self, hello: &protocol::ClientHello) -> Result<(), ConnectionAuthError> {
586 if &hello.secret == self {
587 Ok(())
588 } else {
589 Err(ConnectionAuthError::InvalidSecret)
590 }
591 }
592
593 async fn accept(&self, _auth_context: Self::AuthContext, _connection: ConnectionHandle<T>) {}
594}