// ombrac_client/client.rs

1use std::future::Future;
2use std::io;
3use std::sync::Arc;
4#[cfg(feature = "datagram")]
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use arc_swap::{ArcSwap, Guard};
9use bytes::Bytes;
10use futures::{SinkExt, StreamExt};
11use tokio::io::AsyncWriteExt;
12use tokio::sync::Mutex;
13use tokio_util::codec::Framed;
14#[cfg(feature = "datagram")]
15use tokio_util::sync::CancellationToken;
16
17use ombrac::codec::{UpstreamMessage, length_codec};
18use ombrac::protocol::{
19    self, Address, ClientConnect, ClientHello, HandshakeError, PROTOCOLS_VERSION, Secret,
20    ServerHandshakeResponse,
21};
22use ombrac_macros::{error, info, warn};
23use ombrac_transport::{Connection, Initiator};
24
25#[cfg(feature = "datagram")]
26use datagram::dispatcher::UdpDispatcher;
27#[cfg(feature = "datagram")]
28pub use datagram::session::UdpSession;
29
30/// The central client responsible for managing the connection to the server.
31///
32/// This client handles TCP stream creation and delegates UDP session management
33/// to a dedicated `UdpDispatcher`. It ensures the connection stays alive
34/// through a retry mechanism.
35pub struct Client<T, C> {
36    // Inner state is Arc'd to be shared with background tasks and UDP sessions.
37    inner: Arc<ClientInner<T, C>>,
38    // The handle to the background UDP dispatcher task.
39    #[cfg(feature = "datagram")]
40    _dispatcher_handle: tokio::task::JoinHandle<()>,
41}
42
43/// The shared inner state of the `Client`.
44///
45/// This struct holds all the components necessary for the client's operation,
46/// such as the transport, the current connection, and session management state.
47pub(crate) struct ClientInner<T, C> {
48    pub(crate) transport: T,
49    pub(crate) connection: ArcSwap<C>,
50    // A lock to ensure only one task attempts to reconnect at a time.
51    reconnect_lock: Mutex<()>,
52    secret: Secret,
53    options: Bytes,
54    #[cfg(feature = "datagram")]
55    session_id_counter: AtomicU64,
56    #[cfg(feature = "datagram")]
57    pub(crate) udp_dispatcher: UdpDispatcher,
58    // Token to signal all background tasks to shut down gracefully.
59    #[cfg(feature = "datagram")]
60    pub(crate) shutdown_token: CancellationToken,
61}
62
63impl<T, C> Client<T, C>
64where
65    T: Initiator<Connection = C>,
66    C: Connection,
67{
68    /// Creates a new `Client` and establishes a connection to the server.
69    ///
70    /// This involves performing a handshake and spawning a background task to
71    /// handle incoming UDP datagrams.
72    pub async fn new(transport: T, secret: Secret, options: Option<Bytes>) -> io::Result<Self> {
73        let options = options.unwrap_or_default();
74        let connection = handshake(&transport, secret, options.clone()).await?;
75
76        let inner = Arc::new(ClientInner {
77            transport,
78            connection: ArcSwap::new(Arc::new(connection)),
79            reconnect_lock: Mutex::new(()),
80            secret,
81            options,
82            #[cfg(feature = "datagram")]
83            session_id_counter: AtomicU64::new(1),
84            #[cfg(feature = "datagram")]
85            udp_dispatcher: UdpDispatcher::new(),
86            #[cfg(feature = "datagram")]
87            shutdown_token: CancellationToken::new(),
88        });
89
90        // Spawn the background task that reads all UDP datagrams and dispatches them.
91        #[cfg(feature = "datagram")]
92        let dispatcher_handle = tokio::spawn(UdpDispatcher::run(Arc::clone(&inner)));
93
94        Ok(Self {
95            inner,
96            #[cfg(feature = "datagram")]
97            _dispatcher_handle: dispatcher_handle,
98        })
99    }
100
101    /// Establishes a new UDP session through the tunnel.
102    ///
103    /// This returns a `UdpSession` object, which provides a socket-like API
104    /// for sending and receiving UDP datagrams over the existing connection.
105    #[cfg(feature = "datagram")]
106    pub fn open_associate(&self) -> UdpSession<T, C> {
107        let session_id = self.inner.new_session_id();
108        info!(
109            "[Client] New UDP session created with session_id={}",
110            session_id
111        );
112        let receiver = self.inner.udp_dispatcher.register_session(session_id);
113
114        UdpSession::new(session_id, Arc::clone(&self.inner), receiver)
115    }
116
117    /// Opens a new bidirectional stream for TCP-like communication.
118    ///
119    /// This method negotiates a new stream with the server, which will then
120    /// connect to the specified destination address.
121    pub async fn open_bidirectional(&self, dest_addr: Address) -> io::Result<C::Stream> {
122        let mut stream = self
123            .inner
124            .with_retry(|conn| async move { conn.open_bidirectional().await })
125            .await?;
126
127        let connect_message = UpstreamMessage::Connect(ClientConnect { address: dest_addr });
128        let encoded_bytes = protocol::encode(&connect_message)?;
129
130        // The protocol requires a length-prefixed message.
131        stream.write_u32(encoded_bytes.len() as u32).await?;
132        stream.write_all(&encoded_bytes).await?;
133
134        Ok(stream)
135    }
136
137    // Rebind the transport to a new socket to ensure a clean state for reconnection.
138    pub async fn rebind(&self) -> io::Result<()> {
139        self.inner.transport.rebind().await
140    }
141}
142
143impl<T, C> Drop for Client<T, C> {
144    fn drop(&mut self) {
145        // Signal the background dispatcher to shut down when the client is dropped.
146        #[cfg(feature = "datagram")]
147        self.inner.shutdown_token.cancel();
148    }
149}
150
151// --- Internal Implementation ---
152
153impl<T, C> ClientInner<T, C>
154where
155    T: Initiator<Connection = C>,
156    C: Connection,
157{
158    /// Atomically generates a new unique session ID.
159    #[cfg(feature = "datagram")]
160    pub(crate) fn new_session_id(&self) -> u64 {
161        self.session_id_counter.fetch_add(1, Ordering::Relaxed)
162    }
163
164    /// A wrapper function that adds retry/reconnect logic to a connection operation.
165    ///
166    /// It executes the provided `operation`. If the operation fails with a
167    /// connection-related error, it attempts to reconnect and retries the
168    /// operation once.
169    pub(crate) async fn with_retry<F, Fut, R>(&self, operation: F) -> io::Result<R>
170    where
171        F: Fn(Guard<Arc<C>>) -> Fut,
172        Fut: Future<Output = io::Result<R>>,
173    {
174        let connection = self.connection.load();
175        // Use the pointer address as a unique ID for the connection instance.
176        let old_conn_id = Arc::as_ptr(&connection) as usize;
177
178        match operation(connection).await {
179            Ok(result) => Ok(result),
180            Err(e) if is_connection_error(&e) => {
181                warn!(
182                    "Connection error detected: {}. Attempting to reconnect...",
183                    e
184                );
185                self.reconnect(old_conn_id).await?;
186                let new_connection = self.connection.load();
187                operation(new_connection).await
188            }
189            Err(e) => Err(e),
190        }
191    }
192
193    /// Handles the reconnection logic.
194    ///
195    /// It uses a mutex to prevent multiple tasks from trying to reconnect simultaneously.
196    async fn reconnect(&self, old_conn_id: usize) -> io::Result<()> {
197        let _lock = self.reconnect_lock.lock().await;
198
199        let current_conn_id = Arc::as_ptr(&self.connection.load()) as usize;
200        // Check if another task has already reconnected.
201        if current_conn_id == old_conn_id {
202            // Rebind the transport to a new socket to ensure a clean state for reconnection.
203            self.transport.rebind().await?;
204
205            let new_connection =
206                handshake(&self.transport, self.secret, self.options.clone()).await?;
207            self.connection.store(Arc::new(new_connection));
208            info!("Reconnection successful");
209        }
210
211        Ok(())
212    }
213}
214
215/// Performs the initial handshake with the server.
216async fn handshake<T, C>(transport: &T, secret: Secret, options: Bytes) -> io::Result<C>
217where
218    T: Initiator<Connection = C>,
219    C: Connection,
220{
221    const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
222
223    let do_handshake = async {
224        let connection = transport.connect().await?;
225        let mut stream = connection.open_bidirectional().await?;
226
227        let hello_message = UpstreamMessage::Hello(ClientHello {
228            version: PROTOCOLS_VERSION,
229            secret,
230            options,
231        });
232
233        let encoded_bytes = protocol::encode(&hello_message)?;
234        let mut framed = Framed::new(&mut stream, length_codec());
235
236        framed.send(encoded_bytes).await?;
237
238        match framed.next().await {
239            Some(Ok(payload)) => {
240                let response: ServerHandshakeResponse = protocol::decode(&payload)?;
241                match response {
242                    ServerHandshakeResponse::Ok => {
243                        info!("Handshake with server successful");
244                        stream.shutdown().await?;
245                        Ok(connection)
246                    }
247                    ServerHandshakeResponse::Err(e) => {
248                        error!("Handshake failed: {:?}", e);
249                        let err_kind = match e {
250                            HandshakeError::InvalidSecret => io::ErrorKind::PermissionDenied,
251                            _ => io::ErrorKind::InvalidData,
252                        };
253                        Err(io::Error::new(
254                            err_kind,
255                            format!("Server rejected handshake: {:?}", e),
256                        ))
257                    }
258                }
259            }
260            Some(Err(e)) => Err(e),
261            None => Err(io::Error::new(
262                io::ErrorKind::UnexpectedEof,
263                "Connection closed by server during handshake",
264            )),
265        }
266    };
267
268    match tokio::time::timeout(HANDSHAKE_TIMEOUT, do_handshake).await {
269        Ok(result) => result,
270        Err(_) => Err(io::Error::new(
271            io::ErrorKind::TimedOut,
272            "Client hello timed out",
273        )),
274    }
275}
276
/// Returns `true` if `e` indicates the underlying connection was lost or is
/// otherwise unusable, so reconnecting may help.
///
/// `ClientInner::with_retry` uses this to decide whether to reconnect and retry.
/// `ConnectionAborted` is included because it is how a connection terminated
/// by the remote side is reported. Without it, a peer-closed connection would
/// never trigger a reconnect. `HostUnreachable` joins `NetworkUnreachable` as a
/// routing-level loss.
fn is_connection_error(e: &io::Error) -> bool {
    matches!(
        e.kind(),
        io::ErrorKind::ConnectionReset
            | io::ErrorKind::ConnectionAborted
            | io::ErrorKind::BrokenPipe
            | io::ErrorKind::NotConnected
            | io::ErrorKind::TimedOut
            | io::ErrorKind::UnexpectedEof
            | io::ErrorKind::NetworkUnreachable
            | io::ErrorKind::HostUnreachable
    )
}
289
/// UDP-over-tunnel support: sending with fragmentation, receiving with
/// reassembly, dispatching incoming datagrams to sessions, and the
/// user-facing `UdpSession` type.
#[cfg(feature = "datagram")]
mod datagram {
    use std::io;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    use bytes::Bytes;
    use ombrac::protocol::{Address, UdpPacket};
    use ombrac::reassembly::UdpReassembler;
    use ombrac_macros::{debug, warn};
    use ombrac_transport::{Connection, Initiator};

    use super::ClientInner;

    /// Sends one UDP datagram for `session_id` to `dest_addr` over the tunnel.
    ///
    /// If the payload fits in a single transport datagram, it is sent as one
    /// `UdpPacket::Unfragmented`. Otherwise it is split with
    /// `UdpPacket::split_packet` under a fresh id taken from
    /// `fragment_id_counter`. Every packet goes through
    /// `ClientInner::with_retry`, so a connection-level error triggers one
    /// reconnect-and-retry per packet. An empty `data` is a no-op.
    ///
    /// # Errors
    /// Returns encoding errors and any send error that survives the retry.
    pub(crate) async fn send_datagram<T, C>(
        inner: &ClientInner<T, C>,
        session_id: u64,
        dest_addr: Address,
        data: Bytes,
        fragment_id_counter: &AtomicU32,
    ) -> io::Result<()>
    where
        T: Initiator<Connection = C>,
        C: Connection,
    {
        if data.is_empty() {
            return Ok(());
        }

        // Read the size limit through a temporary guard, so the connection is
        // not pinned for the rest of this function (the sends below may reconnect).
        // 1350 is a conservative fallback when the transport reports no limit.
        let max_datagram_size = inner.connection.load().max_datagram_size().unwrap_or(1350);
        // Leave room for the fragment header; `.max(1)` keeps the payload size non-zero.
        let overhead = UdpPacket::fragmented_overhead();
        let max_payload_size = max_datagram_size.saturating_sub(overhead).max(1);

        if data.len() <= max_payload_size {
            let packet = UdpPacket::Unfragmented {
                session_id,
                address: dest_addr,
                data,
            };
            let encoded = packet.encode()?;
            inner
                .with_retry(|conn| {
                    let data_for_attempt = encoded.clone();
                    async move { conn.send_datagram(data_for_attempt).await }
                })
                .await?;
        } else {
            // The packet is too large and must be fragmented.
            debug!(
                "[Session][{}] Sending packet for {} is too large ({} > max {}), fragmenting...",
                session_id,
                dest_addr,
                data.len(),
                max_payload_size
            );

            let fragment_id = fragment_id_counter.fetch_add(1, Ordering::Relaxed);
            let fragments =
                UdpPacket::split_packet(session_id, dest_addr, data, max_payload_size, fragment_id);

            for fragment in fragments {
                let packet_bytes = fragment.encode()?;
                inner
                    .with_retry(|conn| {
                        let data_for_attempt = packet_bytes.clone();
                        async move { conn.send_datagram(data_for_attempt).await }
                    })
                    .await?;
            }
        }
        Ok(())
    }

    /// Reads the next complete UDP datagram from the connection.
    ///
    /// Loops until a whole datagram is available. Fragments are fed to
    /// `reassembler`, and packets that fail to decode or reassemble are logged
    /// and discarded. Each read goes through `ClientInner::with_retry`.
    ///
    /// Returns `(session_id, address, payload)`.
    ///
    /// # Errors
    /// Returns a read error that survives the retry.
    pub(crate) async fn read_datagram<T, C>(
        inner: &ClientInner<T, C>,
        reassembler: &mut UdpReassembler,
    ) -> io::Result<(u64, Address, Bytes)>
    where
        T: Initiator<Connection = C>,
        C: Connection,
    {
        loop {
            let packet_bytes = inner
                .with_retry(|conn| async move { conn.read_datagram().await })
                .await?;

            let packet = match UdpPacket::decode(&packet_bytes) {
                Ok(packet) => packet,
                Err(_e) => {
                    warn!("Failed to decode UDP packet: {}. Discarding.", _e);
                    continue; // Skip malformed packets.
                }
            };

            match reassembler.process(packet).await {
                Ok(Some((session_id, address, data))) => {
                    return Ok((session_id, address, data));
                }
                Ok(None) => {
                    continue; // Fragment received, continue reading.
                }
                Err(_e) => {
                    warn!("Reassembly error: {}. Discarding fragment.", _e);
                    continue; // Reassembly error, wait for the next valid packet.
                }
            }
        }
    }

    /// Contains the `UdpDispatcher` which runs as a background task.
    pub(crate) mod dispatcher {
        use super::*;
        use dashmap::DashMap;
        use tokio::sync::mpsc;

        type UdpSessionSender = mpsc::Sender<(Bytes, Address)>;

        /// Manages all active UDP sessions and dispatches incoming datagrams.
        pub(crate) struct UdpDispatcher {
            /// Maps a session_id to the sender feeding that `UdpSession`.
            dispatch_map: DashMap<u64, UdpSessionSender>,
            /// Source of fragment ids for outgoing fragmented datagrams.
            fragment_id_counter: AtomicU32,
        }

        impl UdpDispatcher {
            pub(crate) fn new() -> Self {
                Self {
                    dispatch_map: DashMap::new(),
                    fragment_id_counter: AtomicU32::new(0),
                }
            }

            /// The main loop for the background UDP dispatcher task.
            ///
            /// Continuously reads datagrams from the server, reassembles them,
            /// and forwards them to the matching `UdpSession`. On a read error
            /// it backs off for one second before trying again; that wait also
            /// watches the shutdown token.
            ///
            /// When the shutdown token is cancelled, the loop exits and every
            /// session sender is dropped. Sessions that outlive the client then
            /// see `recv_from` return `None` once their buffered datagrams are
            /// drained, instead of waiting forever.
            pub(crate) async fn run<T, C>(inner: Arc<ClientInner<T, C>>)
            where
                T: Initiator<Connection = C>,
                C: Connection,
            {
                let mut reassembler = UdpReassembler::default();

                'dispatch: loop {
                    tokio::select! {
                        // Listen for the shutdown signal.
                        _ = inner.shutdown_token.cancelled() => {
                            break 'dispatch;
                        }
                        // Read the next datagram from the server.
                        result = read_datagram(&inner, &mut reassembler) => {
                            match result {
                                Ok((session_id, address, data)) => {
                                    inner.udp_dispatcher.dispatch(session_id, data, address);
                                }
                                Err(_e) => {
                                    warn!("Error reading datagram: {}. Retrying after delay...", _e);
                                    // Back off to avoid a tight loop on persistent errors,
                                    // while staying responsive to shutdown.
                                    tokio::select! {
                                        _ = inner.shutdown_token.cancelled() => {
                                            break 'dispatch;
                                        }
                                        _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => {}
                                    }
                                }
                            }
                        }
                    }
                }

                // Close every session channel so pending receivers are woken with `None`.
                inner.udp_dispatcher.dispatch_map.clear();
            }

            /// Forwards a received datagram to the appropriate session.
            ///
            /// The sender is cloned out of the map first, so no DashMap shard
            /// lock is held while sending or removing. Holding a map reference
            /// across an await, or while calling `remove`, can deadlock.
            /// `try_send` is used so one session with a full buffer cannot stall
            /// delivery to all other sessions; in that case the datagram is
            /// dropped, as UDP allows. A closed channel means the session's
            /// receiver is gone, so its entry is removed.
            fn dispatch(&self, session_id: u64, data: Bytes, address: Address) {
                let Some(tx) = self
                    .dispatch_map
                    .get(&session_id)
                    .map(|entry| entry.value().clone())
                else {
                    warn!(
                        "[Session][{}] Received datagram for UNKNOWN or CLOSED",
                        session_id
                    );
                    return;
                };

                match tx.try_send((data, address)) {
                    Ok(()) => {}
                    Err(mpsc::error::TrySendError::Full(_)) => {
                        debug!(
                            "[Session][{}] Receive buffer full, dropping datagram",
                            session_id
                        );
                    }
                    Err(mpsc::error::TrySendError::Closed(_)) => {
                        self.dispatch_map.remove(&session_id);
                    }
                }
            }

            /// Registers a new session and returns a receiver for it.
            ///
            /// The receiver buffers up to 128 datagrams.
            pub(crate) fn register_session(
                &self,
                session_id: u64,
            ) -> mpsc::Receiver<(Bytes, Address)> {
                let (tx, rx) = mpsc::channel(128); // Channel buffer size
                self.dispatch_map.insert(session_id, tx);
                rx
            }

            /// Unregisters a session when it is dropped.
            pub(crate) fn unregister_session(&self, session_id: u64) {
                self.dispatch_map.remove(&session_id);
            }

            /// Provides access to the fragment ID counter for sending datagrams.
            pub(crate) fn fragment_id_counter(&self) -> &AtomicU32 {
                &self.fragment_id_counter
            }
        }
    }

    /// Represents a virtual UDP session over the tunnel.
    pub mod session {
        use super::*;
        use crate::client::datagram::ClientInner;
        use tokio::sync::mpsc;

        /// A virtual UDP session that provides a socket-like API.
        ///
        /// When this struct is dropped, its session is unregistered from the
        /// dispatcher, so the dispatch map does not grow without bound.
        pub struct UdpSession<T, C>
        where
            T: Initiator<Connection = C>,
            C: Connection,
        {
            session_id: u64,
            client_inner: Arc<ClientInner<T, C>>,
            receiver: mpsc::Receiver<(Bytes, Address)>,
        }

        impl<T, C> UdpSession<T, C>
        where
            T: Initiator<Connection = C>,
            C: Connection,
        {
            /// Creates a new `UdpSession`.
            ///
            /// Called by `Client::open_associate`, which registers `session_id`
            /// with the dispatcher and passes in the matching `receiver`.
            pub(crate) fn new(
                session_id: u64,
                client_inner: Arc<ClientInner<T, C>>,
                receiver: mpsc::Receiver<(Bytes, Address)>,
            ) -> Self {
                Self {
                    session_id,
                    client_inner,
                    receiver,
                }
            }

            /// Sends a UDP datagram to the specified destination through the tunnel.
            ///
            /// # Errors
            /// Returns encoding errors and any send error that survives one
            /// reconnect-and-retry.
            pub async fn send_to(&self, data: Bytes, dest_addr: Address) -> io::Result<()> {
                send_datagram(
                    &self.client_inner,
                    self.session_id,
                    dest_addr,
                    data,
                    self.client_inner.udp_dispatcher.fragment_id_counter(),
                )
                .await
            }

            /// Receives a UDP datagram from the tunnel for this session.
            ///
            /// Returns the received data and its original sender address.
            /// Returns `None` once the session's channel is closed and drained,
            /// for example after the owning `Client` has been dropped.
            pub async fn recv_from(&mut self) -> Option<(Bytes, Address)> {
                self.receiver.recv().await
            }
        }

        impl<T, C> Drop for UdpSession<T, C>
        where
            T: Initiator<Connection = C>,
            C: Connection,
        {
            fn drop(&mut self) {
                // Remove this session's sender from the dispatcher so the map
                // does not grow indefinitely.
                self.client_inner
                    .udp_dispatcher
                    .unregister_session(self.session_id);
            }
        }
    }
}