ombrac_client/
client.rs

1use std::future::Future;
2use std::io;
3use std::sync::Arc;
4#[cfg(feature = "datagram")]
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use arc_swap::{ArcSwap, Guard};
9use bytes::Bytes;
10use futures::{SinkExt, StreamExt};
11use tokio::io::AsyncWriteExt;
12use tokio::sync::Mutex;
13use tokio_util::codec::Framed;
14#[cfg(feature = "datagram")]
15use tokio_util::sync::CancellationToken;
16
17use ombrac::codec::{ServerHandshakeResponse, UpstreamMessage, length_codec};
18use ombrac::protocol::{
19    self, Address, ClientConnect, ClientHello, HandshakeError, PROTOCOLS_VERSION, Secret,
20};
21use ombrac_macros::{error, info, warn};
22use ombrac_transport::{Connection, Initiator};
23
24#[cfg(feature = "datagram")]
25use datagram::dispatcher::UdpDispatcher;
26#[cfg(feature = "datagram")]
27pub use datagram::session::UdpSession;
28
29/// The central client responsible for managing the connection to the server.
30///
31/// This client handles TCP stream creation and delegates UDP session management
32/// to a dedicated `UdpDispatcher`. It ensures the connection stays alive
33/// through a retry mechanism.
34pub struct Client<T, C> {
35    // Inner state is Arc'd to be shared with background tasks and UDP sessions.
36    inner: Arc<ClientInner<T, C>>,
37    // The handle to the background UDP dispatcher task.
38    #[cfg(feature = "datagram")]
39    _dispatcher_handle: tokio::task::JoinHandle<()>,
40}
41
42/// The shared inner state of the `Client`.
43///
44/// This struct holds all the components necessary for the client's operation,
45/// such as the transport, the current connection, and session management state.
46pub(crate) struct ClientInner<T, C> {
47    pub(crate) transport: T,
48    pub(crate) connection: ArcSwap<C>,
49    // A lock to ensure only one task attempts to reconnect at a time.
50    reconnect_lock: Mutex<()>,
51    secret: Secret,
52    options: Bytes,
53    #[cfg(feature = "datagram")]
54    session_id_counter: AtomicU64,
55    #[cfg(feature = "datagram")]
56    pub(crate) udp_dispatcher: UdpDispatcher,
57    // Token to signal all background tasks to shut down gracefully.
58    #[cfg(feature = "datagram")]
59    pub(crate) shutdown_token: CancellationToken,
60}
61
62impl<T, C> Client<T, C>
63where
64    T: Initiator<Connection = C>,
65    C: Connection,
66{
67    /// Creates a new `Client` and establishes a connection to the server.
68    ///
69    /// This involves performing a handshake and spawning a background task to
70    /// handle incoming UDP datagrams.
71    pub async fn new(transport: T, secret: Secret, options: Option<Bytes>) -> io::Result<Self> {
72        let options = options.unwrap_or_default();
73        let connection = handshake(&transport, secret, options.clone()).await?;
74
75        let inner = Arc::new(ClientInner {
76            transport,
77            connection: ArcSwap::new(Arc::new(connection)),
78            reconnect_lock: Mutex::new(()),
79            secret,
80            options,
81            #[cfg(feature = "datagram")]
82            session_id_counter: AtomicU64::new(1),
83            #[cfg(feature = "datagram")]
84            udp_dispatcher: UdpDispatcher::new(),
85            #[cfg(feature = "datagram")]
86            shutdown_token: CancellationToken::new(),
87        });
88
89        // Spawn the background task that reads all UDP datagrams and dispatches them.
90        #[cfg(feature = "datagram")]
91        let dispatcher_handle = tokio::spawn(UdpDispatcher::run(Arc::clone(&inner)));
92
93        Ok(Self {
94            inner,
95            #[cfg(feature = "datagram")]
96            _dispatcher_handle: dispatcher_handle,
97        })
98    }
99
100    /// Establishes a new UDP session through the tunnel.
101    ///
102    /// This returns a `UdpSession` object, which provides a socket-like API
103    /// for sending and receiving UDP datagrams over the existing connection.
104    #[cfg(feature = "datagram")]
105    pub fn open_associate(&self) -> UdpSession<T, C> {
106        let session_id = self.inner.new_session_id();
107        info!(
108            "[Client] New UDP session created with session_id={}",
109            session_id
110        );
111        let receiver = self.inner.udp_dispatcher.register_session(session_id);
112
113        UdpSession::new(session_id, Arc::clone(&self.inner), receiver)
114    }
115
116    /// Opens a new bidirectional stream for TCP-like communication.
117    ///
118    /// This method negotiates a new stream with the server, which will then
119    /// connect to the specified destination address.
120    pub async fn open_bidirectional(&self, dest_addr: Address) -> io::Result<C::Stream> {
121        let mut stream = self
122            .inner
123            .with_retry(|conn| async move { conn.open_bidirectional().await })
124            .await?;
125
126        let connect_message = UpstreamMessage::Connect(ClientConnect { address: dest_addr });
127        let encoded_bytes = protocol::encode(&connect_message)?;
128
129        // The protocol requires a length-prefixed message.
130        stream.write_u32(encoded_bytes.len() as u32).await?;
131        stream.write_all(&encoded_bytes).await?;
132
133        Ok(stream)
134    }
135
136    // Rebind the transport to a new socket to ensure a clean state for reconnection.
137    pub async fn rebind(&self) -> io::Result<()> {
138        self.inner.transport.rebind().await
139    }
140}
141
142impl<T, C> Drop for Client<T, C> {
143    fn drop(&mut self) {
144        // Signal the background dispatcher to shut down when the client is dropped.
145        #[cfg(feature = "datagram")]
146        self.inner.shutdown_token.cancel();
147    }
148}
149
150// --- Internal Implementation ---
151
152impl<T, C> ClientInner<T, C>
153where
154    T: Initiator<Connection = C>,
155    C: Connection,
156{
157    /// Atomically generates a new unique session ID.
158    #[cfg(feature = "datagram")]
159    pub(crate) fn new_session_id(&self) -> u64 {
160        self.session_id_counter.fetch_add(1, Ordering::Relaxed)
161    }
162
163    /// A wrapper function that adds retry/reconnect logic to a connection operation.
164    ///
165    /// It executes the provided `operation`. If the operation fails with a
166    /// connection-related error, it attempts to reconnect and retries the
167    /// operation once.
168    pub(crate) async fn with_retry<F, Fut, R>(&self, operation: F) -> io::Result<R>
169    where
170        F: Fn(Guard<Arc<C>>) -> Fut,
171        Fut: Future<Output = io::Result<R>>,
172    {
173        let connection = self.connection.load();
174        // Use the pointer address as a unique ID for the connection instance.
175        let old_conn_id = Arc::as_ptr(&connection) as usize;
176
177        match operation(connection).await {
178            Ok(result) => Ok(result),
179            Err(e) if is_connection_error(&e) => {
180                warn!(
181                    "Connection error detected: {}. Attempting to reconnect...",
182                    e
183                );
184                self.reconnect(old_conn_id).await?;
185                let new_connection = self.connection.load();
186                operation(new_connection).await
187            }
188            Err(e) => Err(e),
189        }
190    }
191
192    /// Handles the reconnection logic.
193    ///
194    /// It uses a mutex to prevent multiple tasks from trying to reconnect simultaneously.
195    async fn reconnect(&self, old_conn_id: usize) -> io::Result<()> {
196        let _lock = self.reconnect_lock.lock().await;
197
198        let current_conn_id = Arc::as_ptr(&self.connection.load()) as usize;
199        // Check if another task has already reconnected.
200        if current_conn_id == old_conn_id {
201            // Rebind the transport to a new socket to ensure a clean state for reconnection.
202            self.transport.rebind().await?;
203
204            let new_connection =
205                handshake(&self.transport, self.secret, self.options.clone()).await?;
206            self.connection.store(Arc::new(new_connection));
207            info!("Reconnection successful");
208        }
209
210        Ok(())
211    }
212}
213
214/// Performs the initial handshake with the server.
215async fn handshake<T, C>(transport: &T, secret: Secret, options: Bytes) -> io::Result<C>
216where
217    T: Initiator<Connection = C>,
218    C: Connection,
219{
220    const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
221
222    let do_handshake = async {
223        let connection = transport.connect().await?;
224        let mut stream = connection.open_bidirectional().await?;
225
226        let hello_message = UpstreamMessage::Hello(ClientHello {
227            version: PROTOCOLS_VERSION,
228            secret,
229            options,
230        });
231
232        let encoded_bytes = protocol::encode(&hello_message)?;
233        let mut framed = Framed::new(&mut stream, length_codec());
234
235        framed.send(encoded_bytes).await?;
236
237        match framed.next().await {
238            Some(Ok(payload)) => {
239                let response: ServerHandshakeResponse = protocol::decode(&payload)?;
240                match response {
241                    ServerHandshakeResponse::Ok => {
242                        info!("Handshake with server successful");
243                        stream.shutdown().await?;
244                        Ok(connection)
245                    }
246                    ServerHandshakeResponse::Err(e) => {
247                        error!("Handshake failed: {:?}", e);
248                        let err_kind = match e {
249                            HandshakeError::InvalidSecret => io::ErrorKind::PermissionDenied,
250                            _ => io::ErrorKind::InvalidData,
251                        };
252                        Err(io::Error::new(
253                            err_kind,
254                            format!("Server rejected handshake: {:?}", e),
255                        ))
256                    }
257                }
258            }
259            Some(Err(e)) => Err(e),
260            None => Err(io::Error::new(
261                io::ErrorKind::UnexpectedEof,
262                "Connection closed by server during handshake",
263            )),
264        }
265    };
266
267    match tokio::time::timeout(HANDSHAKE_TIMEOUT, do_handshake).await {
268        Ok(result) => result,
269        Err(_) => Err(io::Error::new(
270            io::ErrorKind::TimedOut,
271            "Client hello timed out",
272        )),
273    }
274}
275
/// Returns `true` if `e` indicates that the connection is lost or unusable.
///
/// Callers use this to decide whether a reconnect-and-retry is worthwhile.
/// `ConnectionAborted` is treated like `ConnectionReset`: in both cases the
/// connection is gone. All other error kinds return `false`, so those errors
/// are propagated instead of retried.
fn is_connection_error(e: &io::Error) -> bool {
    matches!(
        e.kind(),
        io::ErrorKind::ConnectionReset
            | io::ErrorKind::ConnectionAborted
            | io::ErrorKind::BrokenPipe
            | io::ErrorKind::NotConnected
            | io::ErrorKind::TimedOut
            | io::ErrorKind::UnexpectedEof
            | io::ErrorKind::NetworkUnreachable
    )
}
288
/// The `datagram` module encapsulates all logic related to handling UDP datagrams,
/// including session management, packet fragmentation, and reassembly.
#[cfg(feature = "datagram")]
mod datagram {
    use std::io;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    use bytes::Bytes;
    use ombrac::protocol::{Address, UdpPacket};
    use ombrac::reassembly::UdpReassembler;
    use ombrac_macros::{debug, warn};
    use ombrac_transport::{Connection, Initiator};

    use super::ClientInner;

    /// Sends `data` to `dest_addr` on behalf of session `session_id`.
    ///
    /// Empty payloads are ignored (returns `Ok(())` without sending).
    /// The size threshold is the connection's maximum datagram size (1350 when
    /// the transport reports none) minus the fragment header overhead.
    /// Payloads within it are sent as one packet. Larger payloads are split
    /// into fragments that share one fragment ID taken from
    /// `fragment_id_counter`. Every packet is sent with the client's
    /// reconnect-and-retry-once policy.
    ///
    /// # Errors
    ///
    /// Returns encoding or send errors. If a fragment fails to send, the
    /// remaining fragments are not sent.
    pub(crate) async fn send_datagram<T, C>(
        inner: &ClientInner<T, C>,
        session_id: u64,
        dest_addr: Address,
        data: Bytes,
        fragment_id_counter: &AtomicU32,
    ) -> io::Result<()>
    where
        T: Initiator<Connection = C>,
        C: Connection,
    {
        if data.is_empty() {
            return Ok(());
        }

        let connection = inner.connection.load();
        // Use a conservative default MTU if the transport doesn't provide one.
        let max_datagram_size = connection.max_datagram_size().unwrap_or(1350);
        // Budget for the fragment header on both paths so one threshold decides.
        let overhead = UdpPacket::fragmented_overhead();
        let max_payload_size = max_datagram_size.saturating_sub(overhead).max(1);

        if data.len() <= max_payload_size {
            let packet = UdpPacket::Unfragmented {
                session_id,
                address: dest_addr.clone(),
                data,
            };
            let encoded = packet.encode()?;
            inner
                .with_retry(|conn| {
                    let data_for_attempt = encoded.clone();
                    async move { conn.send_datagram(data_for_attempt).await }
                })
                .await?;
        } else {
            // The packet is too large and must be fragmented.
            debug!(
                "[Session][{}] Sending packet for {} is too large ({} > max {}), fragmenting...",
                session_id,
                dest_addr,
                data.len(),
                max_payload_size
            );

            let fragment_id = fragment_id_counter.fetch_add(1, Ordering::Relaxed);
            let fragments =
                UdpPacket::split_packet(session_id, dest_addr, data, max_payload_size, fragment_id);

            for fragment in fragments {
                let packet_bytes = fragment.encode()?;
                inner
                    .with_retry(|conn| {
                        let data_for_attempt = packet_bytes.clone();
                        async move { conn.send_datagram(data_for_attempt).await }
                    })
                    .await?;
            }
        }
        Ok(())
    }

    /// Reads datagrams until one complete packet is available.
    ///
    /// Fragments are fed to `reassembler`. Malformed packets and reassembly
    /// errors are logged and skipped.
    ///
    /// # Errors
    ///
    /// Only errors from reading the connection are returned, and only after
    /// the client's reconnect-and-retry-once policy has been applied.
    pub(crate) async fn read_datagram<T, C>(
        inner: &ClientInner<T, C>,
        reassembler: &mut UdpReassembler,
    ) -> io::Result<(u64, Address, Bytes)>
    where
        T: Initiator<Connection = C>,
        C: Connection,
    {
        loop {
            let packet_bytes = inner
                .with_retry(|conn| async move { conn.read_datagram().await })
                .await?;

            let packet = match UdpPacket::decode(&packet_bytes) {
                Ok(packet) => packet,
                Err(_e) => {
                    warn!("Failed to decode UDP packet: {}. Discarding.", _e);
                    continue; // Skip malformed packets.
                }
            };

            match reassembler.process(packet).await {
                Ok(Some((session_id, address, data))) => {
                    return Ok((session_id, address, data));
                }
                Ok(None) => {
                    continue; // Fragment received, continue reading.
                }
                Err(_e) => {
                    warn!("Reassembly error: {}. Discarding fragment.", _e);
                    continue; // Reassembly error, wait for the next valid packet.
                }
            }
        }
    }

    /// Contains the `UdpDispatcher` which runs as a background task.
    pub(crate) mod dispatcher {
        use super::*;
        use dashmap::DashMap;
        use tokio::sync::mpsc;

        type UdpSessionSender = mpsc::Sender<(Bytes, Address)>;

        /// Manages all active UDP sessions and dispatches incoming datagrams.
        pub(crate) struct UdpDispatcher {
            // Maps a session_id to a sender that forwards data to the `UdpSession`.
            dispatch_map: DashMap<u64, UdpSessionSender>,
            // Source of fragment IDs for outgoing fragmented datagrams.
            fragment_id_counter: AtomicU32,
        }

        impl UdpDispatcher {
            pub(crate) fn new() -> Self {
                Self {
                    dispatch_map: DashMap::new(),
                    fragment_id_counter: AtomicU32::new(0),
                }
            }

            /// The main loop for the background UDP dispatcher task.
            ///
            /// Reads datagrams from the server (reassembling fragments) and
            /// forwards each one to its `UdpSession` until `shutdown_token` is
            /// cancelled. After a read error it sleeps 1 second before retrying.
            pub(crate) async fn run<T, C>(inner: Arc<ClientInner<T, C>>)
            where
                T: Initiator<Connection = C>,
                C: Connection,
            {
                let mut reassembler = UdpReassembler::default();

                loop {
                    tokio::select! {
                        // Listen for the shutdown signal.
                        _ = inner.shutdown_token.cancelled() => {
                            break;
                        }
                        // Read the next datagram from the server.
                        result = read_datagram(&inner, &mut reassembler) => {
                            match result {
                                Ok((session_id, address, data)) => {
                                    inner.udp_dispatcher.dispatch(session_id, data, address);
                                }
                                Err(_e) => {
                                    warn!("Error reading datagram: {}. Retrying after delay...", _e);
                                    // A small delay to prevent a tight loop on persistent errors.
                                    tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                                }
                            }
                        }
                    }
                }
            }

            /// Forwards a received datagram to the appropriate session without waiting.
            ///
            /// - The sender is cloned out of the map, so no DashMap shard lock
            ///   is held while sending or while removing the entry. Holding a
            ///   `Ref` across `remove` on the same shard would deadlock.
            /// - `try_send` is used so one session that is not reading cannot
            ///   stall delivery to every other session. When a session's queue
            ///   is full the datagram is dropped, as UDP allows.
            /// - If the session's receiver is gone, its entry is removed.
            fn dispatch(&self, session_id: u64, data: Bytes, address: Address) {
                // Clone the sender so the shard read guard is released immediately.
                let Some(tx) = self
                    .dispatch_map
                    .get(&session_id)
                    .map(|entry| entry.value().clone())
                else {
                    warn!(
                        "[Session][{}] Received datagram for UNKNOWN or CLOSED",
                        session_id
                    );
                    return;
                };

                match tx.try_send((data, address)) {
                    Ok(()) => {}
                    Err(mpsc::error::TrySendError::Full(_)) => {
                        debug!(
                            "[Session][{}] Receive queue full, dropping datagram",
                            session_id
                        );
                    }
                    Err(mpsc::error::TrySendError::Closed(_)) => {
                        // The receiver (`UdpSession`) has been dropped; clean up.
                        self.dispatch_map.remove(&session_id);
                    }
                }
            }

            /// Registers a new session and returns a receiver for it.
            ///
            /// The receive queue holds up to 128 datagrams. Beyond that,
            /// incoming datagrams for the session are dropped.
            pub(crate) fn register_session(
                &self,
                session_id: u64,
            ) -> mpsc::Receiver<(Bytes, Address)> {
                let (tx, rx) = mpsc::channel(128); // Channel buffer size
                self.dispatch_map.insert(session_id, tx);
                rx
            }

            /// Unregisters a session when it is dropped.
            pub(crate) fn unregister_session(&self, session_id: u64) {
                self.dispatch_map.remove(&session_id);
            }

            /// Provides access to the fragment ID counter for sending datagrams.
            pub(crate) fn fragment_id_counter(&self) -> &AtomicU32 {
                &self.fragment_id_counter
            }
        }
    }

    /// Represents a virtual UDP session over the tunnel.
    pub mod session {
        use super::*;
        use crate::client::datagram::ClientInner;
        use tokio::sync::mpsc;

        /// A virtual UDP session that provides a socket-like API.
        ///
        /// When this struct is dropped, its session is automatically cleaned up
        /// on the client side to prevent resource leaks.
        pub struct UdpSession<T, C>
        where
            T: Initiator<Connection = C>,
            C: Connection,
        {
            session_id: u64,
            client_inner: Arc<ClientInner<T, C>>,
            receiver: mpsc::Receiver<(Bytes, Address)>,
        }

        impl<T, C> UdpSession<T, C>
        where
            T: Initiator<Connection = C>,
            C: Connection,
        {
            /// Creates a new `UdpSession`. This is called by `Client::open_associate`.
            pub(crate) fn new(
                session_id: u64,
                client_inner: Arc<ClientInner<T, C>>,
                receiver: mpsc::Receiver<(Bytes, Address)>,
            ) -> Self {
                Self {
                    session_id,
                    client_inner,
                    receiver,
                }
            }

            /// Sends a UDP datagram to the specified destination through the tunnel.
            ///
            /// Empty `data` is ignored. Large payloads are fragmented.
            pub async fn send_to(&self, data: Bytes, dest_addr: Address) -> io::Result<()> {
                send_datagram(
                    &self.client_inner,
                    self.session_id,
                    dest_addr,
                    data,
                    self.client_inner.udp_dispatcher.fragment_id_counter(),
                )
                .await
            }

            /// Receives a UDP datagram from the tunnel for this session.
            ///
            /// Returns the received data and its original sender address.
            /// Returns `None` if the sending side of this session's queue has
            /// been closed.
            ///
            /// Each session's queue holds a limited number of datagrams, and
            /// datagrams that arrive while it is full are dropped. Call this
            /// promptly to avoid losses. NOTE: once the owning `Client` is
            /// dropped, the dispatcher stops and this may wait indefinitely.
            pub async fn recv_from(&mut self) -> Option<(Bytes, Address)> {
                self.receiver.recv().await
            }
        }

        impl<T, C> Drop for UdpSession<T, C>
        where
            T: Initiator<Connection = C>,
            C: Connection,
        {
            fn drop(&mut self) {
                // When a session is dropped, remove its dispatcher from the map
                // to prevent the map from growing indefinitely.
                self.client_inner
                    .udp_dispatcher
                    .unregister_session(self.session_id);
            }
        }
    }
}