ombrac_client/
client.rs

1use std::future::Future;
2use std::io;
3use std::sync::Arc;
4#[cfg(feature = "datagram")]
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use arc_swap::{ArcSwap, Guard};
9use bytes::Bytes;
10use futures::{SinkExt, StreamExt};
11use tokio::io::AsyncWriteExt;
12use tokio::sync::Mutex;
13use tokio::time::Instant;
14use tokio_util::codec::Framed;
15#[cfg(feature = "datagram")]
16use tokio_util::sync::CancellationToken;
17
18use ombrac::codec::{UpstreamMessage, length_codec};
19use ombrac::protocol::{
20    self, Address, ClientConnect, ClientHello, HandshakeError, PROTOCOLS_VERSION, Secret,
21    ServerHandshakeResponse,
22};
23use ombrac_macros::{error, info, warn};
24use ombrac_transport::{Connection, Initiator};
25
26#[cfg(feature = "datagram")]
27use datagram::dispatcher::UdpDispatcher;
28#[cfg(feature = "datagram")]
29pub use datagram::session::UdpSession;
30
/// Throttling state for reconnect attempts; lives behind `ClientInner::reconnect_lock`.
struct ReconnectState {
    /// Start time of the latest reconnect attempt that has not (yet) succeeded;
    /// `None` initially and after a successful reconnect.
    last_attempt: Option<Instant>,
    /// Minimum time that must pass after `last_attempt` before another attempt is made.
    backoff: Duration,
}

impl Default for ReconnectState {
    /// No previous attempt, with the initial one-second backoff.
    fn default() -> Self {
        let backoff = Duration::from_secs(1);
        let last_attempt = None;
        ReconnectState {
            last_attempt,
            backoff,
        }
    }
}
44
45/// The central client responsible for managing the connection to the server.
46///
47/// This client handles TCP stream creation and delegates UDP session management
48/// to a dedicated `UdpDispatcher`. It ensures the connection stays alive
49/// through a retry mechanism.
50pub struct Client<T, C> {
51    // Inner state is Arc'd to be shared with background tasks and UDP sessions.
52    inner: Arc<ClientInner<T, C>>,
53    // The handle to the background UDP dispatcher task.
54    #[cfg(feature = "datagram")]
55    _dispatcher_handle: tokio::task::JoinHandle<()>,
56}
57
58/// The shared inner state of the `Client`.
59///
60/// This struct holds all the components necessary for the client's operation,
61/// such as the transport, the current connection, and session management state.
62pub(crate) struct ClientInner<T, C> {
63    pub(crate) transport: T,
64    pub(crate) connection: ArcSwap<C>,
65    // A lock to ensure only one task attempts to reconnect at a time.
66    reconnect_lock: Mutex<ReconnectState>,
67    secret: Secret,
68    options: Bytes,
69    #[cfg(feature = "datagram")]
70    session_id_counter: AtomicU64,
71    #[cfg(feature = "datagram")]
72    pub(crate) udp_dispatcher: UdpDispatcher,
73    // Token to signal all background tasks to shut down gracefully.
74    #[cfg(feature = "datagram")]
75    pub(crate) shutdown_token: CancellationToken,
76}
77
78impl<T, C> Client<T, C>
79where
80    T: Initiator<Connection = C>,
81    C: Connection,
82{
83    /// Creates a new `Client` and establishes a connection to the server.
84    ///
85    /// This involves performing a handshake and spawning a background task to
86    /// handle incoming UDP datagrams.
87    pub async fn new(transport: T, secret: Secret, options: Option<Bytes>) -> io::Result<Self> {
88        let options = options.unwrap_or_default();
89        let connection = handshake(&transport, secret, options.clone()).await?;
90
91        let inner = Arc::new(ClientInner {
92            transport,
93            connection: ArcSwap::new(Arc::new(connection)),
94            reconnect_lock: Mutex::new(ReconnectState::default()),
95            secret,
96            options,
97            #[cfg(feature = "datagram")]
98            session_id_counter: AtomicU64::new(1),
99            #[cfg(feature = "datagram")]
100            udp_dispatcher: UdpDispatcher::new(),
101            #[cfg(feature = "datagram")]
102            shutdown_token: CancellationToken::new(),
103        });
104
105        // Spawn the background task that reads all UDP datagrams and dispatches them.
106        #[cfg(feature = "datagram")]
107        let dispatcher_handle = tokio::spawn(UdpDispatcher::run(Arc::clone(&inner)));
108
109        Ok(Self {
110            inner,
111            #[cfg(feature = "datagram")]
112            _dispatcher_handle: dispatcher_handle,
113        })
114    }
115
116    /// Establishes a new UDP session through the tunnel.
117    ///
118    /// This returns a `UdpSession` object, which provides a socket-like API
119    /// for sending and receiving UDP datagrams over the existing connection.
120    #[cfg(feature = "datagram")]
121    pub fn open_associate(&self) -> UdpSession<T, C> {
122        let session_id = self.inner.new_session_id();
123        info!(
124            "[Client] New UDP session created with session_id={}",
125            session_id
126        );
127        let receiver = self.inner.udp_dispatcher.register_session(session_id);
128
129        UdpSession::new(session_id, Arc::clone(&self.inner), receiver)
130    }
131
132    /// Opens a new bidirectional stream for TCP-like communication.
133    ///
134    /// This method negotiates a new stream with the server, which will then
135    /// connect to the specified destination address.
136    pub async fn open_bidirectional(&self, dest_addr: Address) -> io::Result<C::Stream> {
137        let mut stream = self
138            .inner
139            .with_retry(|conn| async move { conn.open_bidirectional().await })
140            .await?;
141
142        let connect_message = UpstreamMessage::Connect(ClientConnect { address: dest_addr });
143        let encoded_bytes = protocol::encode(&connect_message)?;
144
145        // The protocol requires a length-prefixed message.
146        stream.write_u32(encoded_bytes.len() as u32).await?;
147        stream.write_all(&encoded_bytes).await?;
148
149        Ok(stream)
150    }
151
152    // Rebind the transport to a new socket to ensure a clean state for reconnection.
153    pub async fn rebind(&self) -> io::Result<()> {
154        self.inner.transport.rebind().await
155    }
156}
157
158impl<T, C> Drop for Client<T, C> {
159    fn drop(&mut self) {
160        // Signal the background dispatcher to shut down when the client is dropped.
161        #[cfg(feature = "datagram")]
162        self.inner.shutdown_token.cancel();
163    }
164}
165
166// --- Internal Implementation ---
167
168impl<T, C> ClientInner<T, C>
169where
170    T: Initiator<Connection = C>,
171    C: Connection,
172{
173    /// Atomically generates a new unique session ID.
174    #[cfg(feature = "datagram")]
175    pub(crate) fn new_session_id(&self) -> u64 {
176        self.session_id_counter.fetch_add(1, Ordering::Relaxed)
177    }
178
179    /// A wrapper function that adds retry/reconnect logic to a connection operation.
180    ///
181    /// It executes the provided `operation`. If the operation fails with a
182    /// connection-related error, it attempts to reconnect and retries the
183    /// operation once.
184    pub(crate) async fn with_retry<F, Fut, R>(&self, operation: F) -> io::Result<R>
185    where
186        F: Fn(Guard<Arc<C>>) -> Fut,
187        Fut: Future<Output = io::Result<R>>,
188    {
189        let connection = self.connection.load();
190        // Use the pointer address as a unique ID for the connection instance.
191        let old_conn_id = Arc::as_ptr(&connection) as usize;
192
193        match operation(connection).await {
194            Ok(result) => Ok(result),
195            Err(e) if is_connection_error(&e) => {
196                warn!(
197                    "Connection error detected: {}. Attempting to reconnect...",
198                    e
199                );
200                self.reconnect(old_conn_id).await?;
201                let new_connection = self.connection.load();
202                operation(new_connection).await
203            }
204            Err(e) => Err(e),
205        }
206    }
207
208    /// Handles the reconnection logic.
209    ///
210    /// It uses a mutex to prevent multiple tasks from trying to reconnect simultaneously.
211     async fn reconnect(&self, old_conn_id: usize) -> io::Result<()> {
212        let mut state = self.reconnect_lock.lock().await;
213
214        let current_conn = self.connection.load();
215        let current_conn_id = Arc::as_ptr(&current_conn) as usize;
216        if current_conn_id != old_conn_id {
217            return Ok(());
218        }
219
220        if let Some(last) = state.last_attempt {
221            let elapsed = last.elapsed();
222            if elapsed < state.backoff {
223                let wait_time = state.backoff - elapsed;
224                warn!("Too many reconnect attempts. Global throttling for {:?}...", wait_time);
225
226                drop(state);
227                tokio::time::sleep(wait_time).await; 
228                return Err(io::Error::new(io::ErrorKind::Other, "Reconnect throttled"));
229            }
230        }
231
232        info!("Attemping to reconnect...");
233        
234        state.last_attempt = Some(Instant::now());
235
236        if let Err(e) = self.transport.rebind().await {
237            state.backoff = (state.backoff * 2).min(Duration::from_secs(60));
238            return Err(e);
239        }
240
241        match handshake(&self.transport, self.secret, self.options.clone()).await {
242            Ok(new_connection) => {
243                state.backoff = Duration::from_secs(1);
244                state.last_attempt = None;
245                
246                self.connection.store(Arc::new(new_connection));
247                info!("Reconnection successful");
248                Ok(())
249            }
250            Err(e) => {
251                state.backoff = (state.backoff * 2).min(Duration::from_secs(60));
252                error!("Reconnect failed: {}. Next retry backoff: {:?}", e, state.backoff);
253                Err(e)
254            }
255        }
256    }
257}
258
259/// Performs the initial handshake with the server.
260async fn handshake<T, C>(transport: &T, secret: Secret, options: Bytes) -> io::Result<C>
261where
262    T: Initiator<Connection = C>,
263    C: Connection,
264{
265    const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
266
267    let do_handshake = async {
268        let connection = transport.connect().await?;
269        let mut stream = connection.open_bidirectional().await?;
270
271        let hello_message = UpstreamMessage::Hello(ClientHello {
272            version: PROTOCOLS_VERSION,
273            secret,
274            options,
275        });
276
277        let encoded_bytes = protocol::encode(&hello_message)?;
278        let mut framed = Framed::new(&mut stream, length_codec());
279
280        framed.send(encoded_bytes).await?;
281
282        match framed.next().await {
283            Some(Ok(payload)) => {
284                let response: ServerHandshakeResponse = protocol::decode(&payload)?;
285                match response {
286                    ServerHandshakeResponse::Ok => {
287                        info!("Handshake with server successful");
288                        stream.shutdown().await?;
289                        Ok(connection)
290                    }
291                    ServerHandshakeResponse::Err(e) => {
292                        error!("Handshake failed: {:?}", e);
293                        let err_kind = match e {
294                            HandshakeError::InvalidSecret => io::ErrorKind::PermissionDenied,
295                            _ => io::ErrorKind::InvalidData,
296                        };
297                        Err(io::Error::new(
298                            err_kind,
299                            format!("Server rejected handshake: {:?}", e),
300                        ))
301                    }
302                }
303            }
304            Some(Err(e)) => Err(e),
305            None => Err(io::Error::new(
306                io::ErrorKind::UnexpectedEof,
307                "Connection closed by server during handshake",
308            )),
309        }
310    };
311
312    match tokio::time::timeout(HANDSHAKE_TIMEOUT, do_handshake).await {
313        Ok(result) => result,
314        Err(_) => Err(io::Error::new(
315            io::ErrorKind::TimedOut,
316            "Client hello timed out",
317        )),
318    }
319}
320
/// Returns `true` when `e` indicates that the connection was lost or is unusable.
///
/// Callers use this to decide whether a reconnect followed by a retry is worthwhile.
fn is_connection_error(e: &io::Error) -> bool {
    const CONNECTION_ERROR_KINDS: [io::ErrorKind; 6] = [
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::BrokenPipe,
        io::ErrorKind::NotConnected,
        io::ErrorKind::TimedOut,
        io::ErrorKind::UnexpectedEof,
        io::ErrorKind::NetworkUnreachable,
    ];
    CONNECTION_ERROR_KINDS.contains(&e.kind())
}
333
334/// The `datagram` module encapsulates all logic related to handling UDP datagrams,
335/// including session management, packet fragmentation, and reassembly.
336#[cfg(feature = "datagram")]
337mod datagram {
338    use std::io;
339    use std::sync::Arc;
340    use std::sync::atomic::{AtomicU32, Ordering};
341
342    use bytes::Bytes;
343    use ombrac::protocol::{Address, UdpPacket};
344    use ombrac::reassembly::UdpReassembler;
345    use ombrac_macros::{debug, warn};
346    use ombrac_transport::{Connection, Initiator};
347
348    use super::ClientInner;
349
350    /// Sends a UDP datagram, handling fragmentation if necessary.
351    pub(crate) async fn send_datagram<T, C>(
352        inner: &ClientInner<T, C>,
353        session_id: u64,
354        dest_addr: Address,
355        data: Bytes,
356        fragment_id_counter: &AtomicU32,
357    ) -> io::Result<()>
358    where
359        T: Initiator<Connection = C>,
360        C: Connection,
361    {
362        if data.is_empty() {
363            return Ok(());
364        }
365
366        let connection = inner.connection.load();
367        // Use a conservative default MTU if the transport doesn't provide one.
368        let max_datagram_size = connection.max_datagram_size().unwrap_or(1350);
369        // Leave a reasonable margin for headers.
370        let overhead = UdpPacket::fragmented_overhead();
371        let max_payload_size = max_datagram_size.saturating_sub(overhead).max(1);
372
373        if data.len() <= max_payload_size {
374            let packet = UdpPacket::Unfragmented {
375                session_id,
376                address: dest_addr.clone(),
377                data,
378            };
379            let encoded = packet.encode()?;
380            inner
381                .with_retry(|conn| {
382                    let data_for_attempt = encoded.clone();
383                    async move { conn.send_datagram(data_for_attempt).await }
384                })
385                .await?;
386        } else {
387            // The packet is too large and must be fragmented.
388            debug!(
389                "[Session][{}] Sending packet for {} is too large ({} > max {}), fragmenting...",
390                session_id,
391                dest_addr,
392                data.len(),
393                max_payload_size
394            );
395
396            let fragment_id = fragment_id_counter.fetch_add(1, Ordering::Relaxed);
397            let fragments =
398                UdpPacket::split_packet(session_id, dest_addr, data, max_payload_size, fragment_id);
399
400            for fragment in fragments {
401                let packet_bytes = fragment.encode()?;
402                inner
403                    .with_retry(|conn| {
404                        let data_for_attempt = packet_bytes.clone();
405                        async move { conn.send_datagram(data_for_attempt).await }
406                    })
407                    .await?;
408            }
409        }
410        Ok(())
411    }
412
413    /// Reads a UDP datagram from the connection, handling reassembly.
414    pub(crate) async fn read_datagram<T, C>(
415        inner: &ClientInner<T, C>,
416        reassembler: &mut UdpReassembler,
417    ) -> io::Result<(u64, Address, Bytes)>
418    where
419        T: Initiator<Connection = C>,
420        C: Connection,
421    {
422        loop {
423            let packet_bytes = inner
424                .with_retry(|conn| async move { conn.read_datagram().await })
425                .await?;
426
427            let packet = match UdpPacket::decode(&packet_bytes) {
428                Ok(packet) => packet,
429                Err(_e) => {
430                    warn!("Failed to decode UDP packet: {}. Discarding.", _e);
431                    continue; // Skip malformed packets.
432                }
433            };
434
435            match reassembler.process(packet).await {
436                Ok(Some((session_id, address, data))) => {
437                    return Ok((session_id, address, data));
438                }
439                Ok(None) => {
440                    continue; // Fragment received, continue reading.
441                }
442                Err(_e) => {
443                    warn!("Reassembly error: {}. Discarding fragment.", _e);
444                    continue; // Reassembly error, wait for the next valid packet.
445                }
446            }
447        }
448    }
449
450    /// Contains the `UdpDispatcher` which runs as a background task.
451    pub(crate) mod dispatcher {
452        use std::time::Duration;
453
454        use super::*;
455        use dashmap::DashMap;
456        use tokio::sync::mpsc;
457
458        type UdpSessionSender = mpsc::Sender<(Bytes, Address)>;
459
460        /// Manages all active UDP sessions and dispatches incoming datagrams.
461        pub(crate) struct UdpDispatcher {
462            // Maps a session_id to a sender that forwards data to the `UdpSession`.
463            dispatch_map: DashMap<u64, UdpSessionSender>,
464            fragment_id_counter: AtomicU32,
465        }
466
467        impl UdpDispatcher {
468            pub(crate) fn new() -> Self {
469                Self {
470                    dispatch_map: DashMap::new(),
471                    fragment_id_counter: AtomicU32::new(0),
472                }
473            }
474
475            /// The main loop for the background UDP dispatcher task.
476            ///
477            /// It continuously reads datagrams from the server, reassembles them,
478            /// and forwards them to the correct `UdpSession`.
479            pub(crate) async fn run<T, C>(inner: Arc<ClientInner<T, C>>)
480            where
481                T: Initiator<Connection = C>,
482                C: Connection,
483            {
484                let mut reassembler = UdpReassembler::default();
485                const INITIAL_DELAY: Duration = Duration::from_secs(1);
486                const MAX_DELAY: Duration = Duration::from_secs(60);
487                let mut current_delay = INITIAL_DELAY;
488
489                loop {
490                    tokio::select! {
491                        // Listen for the shutdown signal.
492                        _ = inner.shutdown_token.cancelled() => {
493                            break;
494                        }
495                        // Read the next datagram from the server.
496                        result = read_datagram(&inner, &mut reassembler) => {
497                            match result {
498                                Ok((session_id, address, data)) => {
499                                    if current_delay != INITIAL_DELAY {
500                                        current_delay = INITIAL_DELAY;
501                                    }
502                                    
503                                    inner.udp_dispatcher.dispatch(session_id, data, address).await;
504                                }
505                                Err(_e) => {
506                                    warn!("Error reading datagram: {}. Retrying in {:?}...", _e, current_delay);
507                                    tokio::time::sleep(current_delay).await;
508                                    current_delay = (current_delay * 2).min(MAX_DELAY);
509                                }
510                            }
511                        }
512                    }
513                }
514            }
515
516            /// Forwards a received datagram to the appropriate session.
517            async fn dispatch(&self, session_id: u64, data: Bytes, address: Address) {
518                if let Some(tx) = self.dispatch_map.get(&session_id) {
519                    // If sending fails, the receiver (`UdpSession`) has been dropped.
520                    // It's safe to clean up the entry from the map.
521                    if tx.send((data, address)).await.is_err() {
522                        self.dispatch_map.remove(&session_id);
523                    }
524                } else {
525                    warn!(
526                        "[Session][{}] Received datagram for UNKNOWN or CLOSED",
527                        session_id
528                    );
529                }
530            }
531
532            /// Registers a new session and returns a receiver for it.
533            pub(crate) fn register_session(
534                &self,
535                session_id: u64,
536            ) -> mpsc::Receiver<(Bytes, Address)> {
537                let (tx, rx) = mpsc::channel(128); // Channel buffer size
538                self.dispatch_map.insert(session_id, tx);
539                rx
540            }
541
542            /// Unregisters a session when it is dropped.
543            pub(crate) fn unregister_session(&self, session_id: u64) {
544                self.dispatch_map.remove(&session_id);
545            }
546
547            /// Provides access to the fragment ID counter for sending datagrams.
548            pub(crate) fn fragment_id_counter(&self) -> &AtomicU32 {
549                &self.fragment_id_counter
550            }
551        }
552    }
553
554    /// Represents a virtual UDP session over the tunnel.
555    pub mod session {
556        use super::*;
557        use crate::client::datagram::ClientInner;
558        use tokio::sync::mpsc;
559
560        /// A virtual UDP session that provides a socket-like API.
561        ///
562        /// When this struct is dropped, its session is automatically cleaned up
563        /// on the client side to prevent resource leaks.
564        pub struct UdpSession<T, C>
565        where
566            T: Initiator<Connection = C>,
567            C: Connection,
568        {
569            session_id: u64,
570            client_inner: Arc<ClientInner<T, C>>,
571            receiver: mpsc::Receiver<(Bytes, Address)>,
572        }
573
574        impl<T, C> UdpSession<T, C>
575        where
576            T: Initiator<Connection = C>,
577            C: Connection,
578        {
579            /// Creates a new `UdpSession`. This is called by `Client::new_udp_session`.
580            pub(crate) fn new(
581                session_id: u64,
582                client_inner: Arc<ClientInner<T, C>>,
583                receiver: mpsc::Receiver<(Bytes, Address)>,
584            ) -> Self {
585                Self {
586                    session_id,
587                    client_inner,
588                    receiver,
589                }
590            }
591
592            /// Sends a UDP datagram to the specified destination through the tunnel.
593            pub async fn send_to(&self, data: Bytes, dest_addr: Address) -> io::Result<()> {
594                send_datagram(
595                    &self.client_inner,
596                    self.session_id,
597                    dest_addr,
598                    data,
599                    self.client_inner.udp_dispatcher.fragment_id_counter(),
600                )
601                .await
602            }
603
604            /// Receives a UDP datagram from the tunnel for this session.
605            ///
606            /// Returns the received data and its original sender address.
607            pub async fn recv_from(&mut self) -> Option<(Bytes, Address)> {
608                self.receiver.recv().await
609            }
610        }
611
612        impl<T, C> Drop for UdpSession<T, C>
613        where
614            T: Initiator<Connection = C>,
615            C: Connection,
616        {
617            fn drop(&mut self) {
618                // When a session is dropped, remove its dispatcher from the map
619                // to prevent the map from growing indefinitely.
620                self.client_inner
621                    .udp_dispatcher
622                    .unregister_session(self.session_id);
623            }
624        }
625    }
626}