Skip to main content

ombrac_client/
client.rs

1use std::io;
2use std::sync::Arc;
3#[cfg(feature = "datagram")]
4use std::sync::atomic::AtomicU64;
5
6use bytes::Bytes;
7#[cfg(feature = "datagram")]
8use tokio_util::sync::CancellationToken;
9
10use ombrac::protocol::{Address, Secret};
11use ombrac_transport::{Connection, Initiator};
12
13use crate::connection::BufferedStream;
14use crate::connection::ClientConnection;
15#[cfg(feature = "datagram")]
16use crate::connection::{UdpDispatcher, UdpSession};
17
18/// The central client responsible for managing the connection to the server.
19///
20/// This client handles TCP stream creation and delegates UDP session management
21/// to a dedicated `UdpDispatcher`. It ensures the connection stays alive
22/// through a retry mechanism.
23pub struct Client<T, C>
24where
25    T: Initiator<Connection = C>,
26    C: Connection,
27{
28    // The connection manager handles handshake, reconnection, and stream creation.
29    connection: Arc<ClientConnection<T, C>>,
30    // The handle to the background UDP dispatcher task.
31    #[cfg(feature = "datagram")]
32    _dispatcher_handle: tokio::task::JoinHandle<()>,
33    #[cfg(feature = "datagram")]
34    session_id_counter: Arc<std::sync::atomic::AtomicU64>,
35    #[cfg(feature = "datagram")]
36    udp_dispatcher: Arc<UdpDispatcher>,
37    #[cfg(feature = "datagram")]
38    shutdown_token: CancellationToken,
39}
40
41impl<T, C> Client<T, C>
42where
43    T: Initiator<Connection = C>,
44    C: Connection,
45{
46    /// Creates a new `Client` and establishes a connection to the server.
47    ///
48    /// This involves performing a handshake and spawning a background task to
49    /// handle incoming UDP datagrams.
50    pub async fn new(transport: T, secret: Secret, options: Option<Bytes>) -> io::Result<Self> {
51        let connection = Arc::new(ClientConnection::new(transport, secret, options).await?);
52
53        #[cfg(feature = "datagram")]
54        let session_id_counter = Arc::new(AtomicU64::new(1));
55        #[cfg(feature = "datagram")]
56        let udp_dispatcher = Arc::new(UdpDispatcher::new());
57        #[cfg(feature = "datagram")]
58        let shutdown_token = CancellationToken::new();
59
60        // Spawn the background task that reads all UDP datagrams and dispatches them.
61        #[cfg(feature = "datagram")]
62        let dispatcher_handle = {
63            let connection_clone = Arc::clone(&connection);
64            let dispatcher_clone = Arc::clone(&udp_dispatcher);
65            let shutdown_clone = shutdown_token.clone();
66            tokio::spawn(async move {
67                UdpDispatcher::run(connection_clone, dispatcher_clone, shutdown_clone).await;
68            })
69        };
70
71        Ok(Self {
72            connection,
73            #[cfg(feature = "datagram")]
74            _dispatcher_handle: dispatcher_handle,
75            #[cfg(feature = "datagram")]
76            session_id_counter,
77            #[cfg(feature = "datagram")]
78            udp_dispatcher,
79            #[cfg(feature = "datagram")]
80            shutdown_token,
81        })
82    }
83
84    /// Establishes a new UDP session through the tunnel.
85    ///
86    /// This returns a `UdpSession` object, which provides a socket-like API
87    /// for sending and receiving UDP datagrams over the existing connection.
88    #[cfg(feature = "datagram")]
89    pub fn open_associate(&self) -> UdpSession<T, C> {
90        let session_id = self
91            .session_id_counter
92            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
93        let receiver = self.udp_dispatcher.register_session(session_id);
94
95        UdpSession::new(
96            session_id,
97            Arc::clone(&self.connection),
98            Arc::clone(&self.udp_dispatcher),
99            receiver,
100        )
101    }
102
103    /// Opens a new bidirectional stream for TCP-like communication.
104    ///
105    /// This method negotiates a new stream with the server, which will then
106    /// connect to the specified destination address. It waits for the server's
107    /// connection response before returning, ensuring proper TCP state handling.
108    ///
109    /// The returned stream is wrapped in a `BufferedStream` to ensure that any
110    /// data remaining in the protocol framing buffer is read first, preventing
111    /// data loss when transitioning from message-based to raw stream communication.
112    pub async fn open_bidirectional(
113        &self,
114        dest_addr: Address,
115    ) -> io::Result<BufferedStream<C::Stream>> {
116        self.connection.open_bidirectional(dest_addr).await
117    }
118
119    /// Rebind the transport to a new socket to ensure a clean state for reconnection.
120    pub async fn rebind(&self) -> io::Result<()> {
121        self.connection.rebind().await
122    }
123}
124
125impl<T, C> Drop for Client<T, C>
126where
127    T: Initiator<Connection = C>,
128    C: Connection,
129{
130    fn drop(&mut self) {
131        // Signal the background dispatcher to shut down when the client is dropped.
132        #[cfg(feature = "datagram")]
133        self.shutdown_token.cancel();
134    }
135}