1#[cfg(feature = "datagram")]
2mod datagram;
3mod dns;
4mod stream;
5
6use std::future::Future;
7use std::io;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use std::sync::Weak;
11use std::time::Duration;
12
13use futures::{SinkExt, StreamExt};
14use tokio::sync::OwnedSemaphorePermit;
15use tokio::sync::{Semaphore, broadcast};
16use tokio::task::JoinHandle;
17use tokio_util::codec::Framed;
18use tokio_util::sync::CancellationToken;
19#[cfg(feature = "tracing")]
20use tracing::Instrument;
21
22use ombrac::codec;
23use ombrac::protocol;
24use ombrac_macros::{debug, error, warn};
25use ombrac_transport::{Acceptor, Connection};
26
27use crate::config::ConnectionConfig;
28
29pub struct ClientConnectionProcessor<C: Connection> {
35 transport_connection: Arc<C>,
36 shutdown_token: CancellationToken,
37}
38
39impl<C: Connection> ClientConnectionProcessor<C> {
40 pub async fn handle<A>(
47 connection: C,
48 authenticator: &A,
49 config: Arc<ConnectionConfig>,
50 ) -> io::Result<()>
51 where
52 A: Authenticator<C>,
53 {
54 let (auth_context, connection) =
55 Self::perform_authentication(connection, authenticator, &config).await?;
56
57 let transport_connection = Arc::new(connection);
58
59 authenticator
60 .accept(
61 auth_context,
62 ConnectionHandle {
63 inner: transport_connection.clone(),
64 },
65 )
66 .await;
67
68 let processor = Self {
69 transport_connection,
70 shutdown_token: CancellationToken::new(),
71 };
72
73 processor.run_tunnel_loops().await;
74
75 Ok(())
76 }
77
78 async fn perform_authentication<A: Authenticator<C>>(
79 connection: C,
80 authenticator: &A,
81 config: &ConnectionConfig,
82 ) -> io::Result<(A::AuthContext, C)> {
83 let auth_timeout = Duration::from_secs(config.auth_timeout_secs());
84
85 let mut control_stream = connection.accept_bidirectional().await.map_err(|e| {
87 io::Error::other(format!("failed to accept bidirectional stream: {}", e))
88 })?;
89 let mut control_frame = Framed::new(&mut control_stream, codec::length_codec());
90
91 let hello = Self::read_hello_message(&mut control_frame, auth_timeout).await?;
93
94 #[cfg(feature = "tracing")]
95 Self::trace_auth(&hello);
96
97 let auth_context =
99 Self::verify_authentication(&hello, authenticator, auth_timeout, &mut control_frame)
100 .await?;
101
102 Ok((auth_context, connection))
103 }
104
105 async fn read_hello_message(
107 control_frame: &mut Framed<&mut <C as Connection>::Stream, codec::LengthDelimitedCodec>,
108 timeout: Duration,
109 ) -> io::Result<protocol::ClientHello>
110 where
111 C: Connection,
112 {
113 let payload = tokio::time::timeout(timeout, control_frame.next())
115 .await
116 .map_err(|_| {
117 io::Error::new(
118 io::ErrorKind::TimedOut,
119 format!(
120 "authentication timeout: failed to receive hello message within {:?}",
121 timeout
122 ),
123 )
124 })?
125 .ok_or_else(|| {
126 io::Error::new(io::ErrorKind::UnexpectedEof, "stream closed before hello")
127 })??;
128
129 let message: codec::ClientMessage = protocol::decode(&payload).map_err(|e| {
131 io::Error::new(
132 io::ErrorKind::InvalidData,
133 format!("failed to decode client message: {}", e),
134 )
135 })?;
136
137 match message {
139 codec::ClientMessage::Hello(hello) => Ok(hello),
140 _ => {
141 let stream = control_frame.get_mut();
143 Self::disconnect_with_random_delay(*stream).await;
144 Err(io::Error::new(
145 io::ErrorKind::InvalidData,
146 "authentication failed: invalid message type (expected Hello)",
147 ))
148 }
149 }
150 }
151
152 async fn verify_authentication<A: Authenticator<C>>(
154 hello: &protocol::ClientHello,
155 authenticator: &A,
156 timeout: Duration,
157 control_frame: &mut Framed<&mut <C as Connection>::Stream, codec::LengthDelimitedCodec>,
158 ) -> io::Result<A::AuthContext>
159 where
160 C: Connection,
161 {
162 if hello.version != protocol::PROTOCOL_VERSION {
164 Self::handle_auth_failure(control_frame).await;
165 return Err(io::Error::new(
166 io::ErrorKind::PermissionDenied,
167 "incompatible version",
168 ));
169 }
170
171 let auth_context = tokio::time::timeout(timeout, authenticator.verify(hello)).await??;
173
174 Self::send_auth_ok_response(control_frame, timeout).await?;
175
176 Ok(auth_context)
177 }
178
179 async fn send_auth_ok_response(
181 control_frame: &mut Framed<&mut <C as Connection>::Stream, codec::LengthDelimitedCodec>,
182 timeout: Duration,
183 ) -> io::Result<()>
184 where
185 C: Connection,
186 {
187 tokio::time::timeout(
188 timeout,
189 control_frame.send(protocol::encode(&protocol::ServerAuthResponse::Ok)?),
190 )
191 .await
192 .map_err(|_| {
193 io::Error::new(
194 io::ErrorKind::TimedOut,
195 format!(
196 "authentication timeout: failed to send response within {:?}",
197 timeout
198 ),
199 )
200 })??;
201 Ok(())
202 }
203
204 async fn handle_auth_failure(
206 control_frame: &mut Framed<&mut <C as Connection>::Stream, codec::LengthDelimitedCodec>,
207 ) where
208 C: Connection,
209 {
210 let stream = control_frame.get_mut();
212 Self::disconnect_with_random_delay(*stream).await;
213 }
214
215 async fn disconnect_with_random_delay(stream: &mut C::Stream) {
221 use rand::Rng;
222
223 let delay_ms = {
224 let mut rng = rand::rng();
225 rng.random_range(100..=500)
226 };
227
228 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
229 let _ = tokio::io::AsyncWriteExt::shutdown(stream).await;
230 }
231
232 async fn run_tunnel_loops(&self) {
234 let stream_tunnel_handle = self.spawn_stream_tunnel();
235 #[cfg(feature = "datagram")]
236 let datagram_tunnel_handle = self.spawn_datagram_tunnel();
237
238 #[cfg(not(feature = "datagram"))]
239 let result = stream_tunnel_handle.await;
240
241 #[cfg(feature = "datagram")]
242 let result = tokio::select! {
243 res = stream_tunnel_handle => res,
244 res = datagram_tunnel_handle => res,
245 };
246
247 self.shutdown_token.cancel();
249
250 match result {
251 Ok(Ok(_)) => debug!("connection closed gracefully"),
252 Ok(Err(e)) => debug!("connection closed with internal error: {}", e),
253 Err(e) => warn!("tunnel handler task panicked or failed: {}", e),
254 }
255 }
256
257 fn spawn_stream_tunnel(&self) -> JoinHandle<io::Result<()>> {
258 use crate::connection::stream::StreamTunnel;
259
260 let connection = Arc::clone(&self.transport_connection);
261 let shutdown = self.shutdown_token.child_token();
262 let tunnel = StreamTunnel::new(connection, shutdown);
263
264 #[cfg(not(feature = "tracing"))]
265 let handle = tokio::spawn(tunnel.accept_loop());
266 #[cfg(feature = "tracing")]
267 let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
268
269 handle
270 }
271
272 #[cfg(feature = "datagram")]
273 fn spawn_datagram_tunnel(&self) -> JoinHandle<io::Result<()>> {
274 use crate::connection::datagram::DatagramTunnel;
275
276 let connection = Arc::clone(&self.transport_connection);
277 let shutdown = self.shutdown_token.child_token();
278 let tunnel = DatagramTunnel::new(connection, shutdown);
279
280 #[cfg(not(feature = "tracing"))]
281 let handle = tokio::spawn(tunnel.accept_loop());
282 #[cfg(feature = "tracing")]
283 let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
284
285 handle
286 }
287
288 #[cfg(feature = "tracing")]
289 fn trace_auth(hello: &protocol::ClientHello) {
290 use std::io::Write;
291
292 let mut buf = [0u8; 6];
293 let mut cursor = std::io::Cursor::new(&mut buf[..]);
294
295 for byte in hello.secret.iter().take(3) {
296 let _ = write!(cursor, "{:02x}", byte);
297 }
298
299 if let Ok(hex_str) = std::str::from_utf8(&buf) {
300 tracing::Span::current().record("secret", hex_str);
301 }
302 }
303}
304
305pub struct ConnectionAcceptor<T, A> {
314 acceptor: Arc<T>,
315 authenticator: Arc<A>,
316 connection_semaphore: Arc<Semaphore>,
317 config: Arc<ConnectionConfig>,
318}
319
320impl<T: Acceptor, A: Authenticator<T::Connection> + 'static> ConnectionAcceptor<T, A> {
321 pub fn new(acceptor: T, authenticator: A) -> Self {
325 Self::with_config(
326 acceptor,
327 authenticator,
328 Arc::new(ConnectionConfig::default()),
329 )
330 }
331
332 pub fn with_config(acceptor: T, authenticator: A, config: Arc<ConnectionConfig>) -> Self {
334 let max_connections = config.max_connections();
335 Self {
336 acceptor: Arc::new(acceptor),
337 authenticator: Arc::new(authenticator),
338 connection_semaphore: Arc::new(Semaphore::new(max_connections)),
339 config,
340 }
341 }
342
343 pub async fn accept_loop(&self, mut shutdown_rx: broadcast::Receiver<()>) -> io::Result<()> {
351 loop {
352 tokio::select! {
353 _ = shutdown_rx.recv() => {
354 break;
355 },
356 accepted = self.acceptor.accept() => {
357 Self::handle_incoming_connection(
358 accepted,
359 Arc::clone(&self.authenticator),
360 Arc::clone(&self.connection_semaphore),
361 Arc::clone(&self.config),
362 );
363 },
364 }
365 }
366
367 Ok(())
368 }
369
370 fn handle_incoming_connection(
372 result: io::Result<<T as Acceptor>::Connection>,
373 authenticator: Arc<A>,
374 semaphore: Arc<Semaphore>,
375 config: Arc<ConnectionConfig>,
376 ) {
377 match result {
378 Ok(connection) => match semaphore.try_acquire_owned() {
379 Ok(permit) => {
380 #[cfg(not(feature = "tracing"))]
381 tokio::spawn(Self::process_connection_with_permit(
382 connection,
383 authenticator,
384 permit,
385 config,
386 ));
387 #[cfg(feature = "tracing")]
388 tokio::spawn(
389 Self::process_connection_with_permit(
390 connection,
391 authenticator,
392 permit,
393 config,
394 )
395 .in_current_span(),
396 );
397 }
398 Err(_) => {
399 warn!(
400 "connection rejected: maximum concurrent connections ({}) reached",
401 config.max_connections()
402 );
403 }
404 },
405 Err(err) => {
406 error!("failed to accept connection: {}", err);
407 }
408 }
409 }
410
411 async fn process_connection_with_permit(
415 connection: <T as Acceptor>::Connection,
416 authenticator: Arc<A>,
417 _permit: OwnedSemaphorePermit,
418 config: Arc<ConnectionConfig>,
419 ) {
420 Self::process_connection(connection, authenticator, config).await;
422 }
424
425 #[cfg_attr(feature = "tracing",
426 tracing::instrument(
427 name = "connection",
428 skip_all,
429 fields(
430 id = connection.id(),
431 from = tracing::field::Empty,
432 secret = tracing::field::Empty,
433 reason = tracing::field::Empty
434 )
435 )
436 )]
437 async fn process_connection(
438 connection: <T as Acceptor>::Connection,
439 authenticator: Arc<A>,
440 config: Arc<ConnectionConfig>,
441 ) {
442 #[cfg(feature = "tracing")]
443 if let Ok(addr) = connection.remote_address() {
444 tracing::Span::current().record("from", tracing::field::display(addr));
445 }
446
447 let _result =
448 ClientConnectionProcessor::handle(connection, authenticator.as_ref(), config).await;
449
450 #[cfg(feature = "tracing")]
451 match _result {
452 Ok(_) => {
453 tracing::Span::current().record("reason", "ok");
454 tracing::info!("connection closed");
455 }
456 Err(e) => {
457 tracing::Span::current().record("reason", tracing::field::display(&e));
458 tracing::error!(error = %e, "connection closed with error");
459 }
460 }
461 }
462
463 pub fn local_addr(&self) -> io::Result<SocketAddr> {
464 self.acceptor.local_addr()
465 }
466}
467
/// Handle to an authenticated connection, passed to [`Authenticator::accept`].
///
/// It holds a strong `Arc` to the connection, so storing the handle keeps the
/// connection alive; use `downgrade_inner` to keep only a weak reference.
pub struct ConnectionHandle<C> {
    /// Shared ownership of the transport connection.
    inner: Arc<C>,
}
471
472impl<C: Connection> ConnectionHandle<C> {
473 pub fn downgrade_inner(&self) -> Weak<C> {
474 Arc::downgrade(&self.inner)
475 }
476
477 pub fn close(&self, error_code: u32, reason: &[u8]) {
478 self.inner.close(error_code, reason);
479 }
480}
481
/// Reasons an [`Authenticator`] can reject a client during the handshake.
///
/// Implements [`std::error::Error`]; its `Display` text matches the message of
/// the `io::Error` it converts into.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionAuthError {
    /// The client's protocol version is not supported.
    IncompatibleVersion,
    /// The secret presented by the client does not match.
    InvalidSecret,
    /// Verification could not be completed because of a server-side fault.
    ServerError,
    /// Any other rejection, described by the contained message.
    Other(String),
}

impl std::fmt::Display for ConnectionAuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::IncompatibleVersion => f.write_str("incompatible protocol version"),
            Self::InvalidSecret => f.write_str("invalid authentication secret"),
            Self::ServerError => f.write_str("internal server error during auth"),
            Self::Other(msg) => f.write_str(msg),
        }
    }
}

impl std::error::Error for ConnectionAuthError {}
494
495impl From<ConnectionAuthError> for io::Error {
496 fn from(value: ConnectionAuthError) -> Self {
497 match value {
498 ConnectionAuthError::IncompatibleVersion => {
499 io::Error::new(io::ErrorKind::Unsupported, "incompatible protocol version")
500 }
501 ConnectionAuthError::InvalidSecret => io::Error::new(
502 io::ErrorKind::PermissionDenied,
503 "invalid authentication secret",
504 ),
505 ConnectionAuthError::ServerError => io::Error::new(
506 io::ErrorKind::ConnectionAborted,
507 "internal server error during auth",
508 ),
509 ConnectionAuthError::Other(msg) => io::Error::other(msg),
510 }
511 }
512}
513
514pub trait Authenticator<T>: Send + Sync {
520 type AuthContext: Send;
522
523 fn verify(
529 &self,
530 hello: &protocol::ClientHello,
531 ) -> impl Future<Output = Result<Self::AuthContext, ConnectionAuthError>> + Send;
532
533 fn accept(
539 &self,
540 auth_context: Self::AuthContext,
541 connection: ConnectionHandle<T>,
542 ) -> impl Future<Output = ()> + Send;
543}
544
545impl<T: Send + Sync> Authenticator<T> for ombrac::protocol::Secret {
546 type AuthContext = ();
547
548 async fn verify(&self, hello: &protocol::ClientHello) -> Result<(), ConnectionAuthError> {
549 if &hello.secret == self {
550 Ok(())
551 } else {
552 Err(ConnectionAuthError::InvalidSecret)
553 }
554 }
555
556 async fn accept(&self, _auth_context: Self::AuthContext, _connection: ConnectionHandle<T>) {}
557}