1use std::future::Future;
2use std::io;
3use std::sync::Arc;
4#[cfg(feature = "datagram")]
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use arc_swap::{ArcSwap, Guard};
9use bytes::Bytes;
10use futures::{SinkExt, StreamExt};
11use tokio::io::AsyncWriteExt;
12use tokio::sync::Mutex;
13use tokio_util::codec::Framed;
14#[cfg(feature = "datagram")]
15use tokio_util::sync::CancellationToken;
16
17use ombrac::codec::{ServerHandshakeResponse, UpstreamMessage, length_codec};
18use ombrac::protocol::{
19 self, Address, ClientConnect, ClientHello, HandshakeError, PROTOCOLS_VERSION, Secret,
20};
21use ombrac_macros::{error, info, warn};
22use ombrac_transport::{Connection, Initiator};
23
24#[cfg(feature = "datagram")]
25use datagram::dispatcher::UdpDispatcher;
26#[cfg(feature = "datagram")]
27pub use datagram::session::UdpSession;
28
/// Client for the ombrac protocol.
///
/// All state lives in a shared `ClientInner`; this handle owns one reference to it.
/// Dropping the handle (see the `Drop` impl) cancels the shutdown token, which stops the
/// background UDP dispatcher when the `datagram` feature is enabled.
pub struct Client<T, C> {
    /// Shared state; with the `datagram` feature it is also referenced by every
    /// `UdpSession` and by the dispatcher task.
    inner: Arc<ClientInner<T, C>>,
    /// Handle of the background UDP dispatcher task. It is kept but never awaited;
    /// dropping a Tokio `JoinHandle` detaches the task rather than stopping it.
    #[cfg(feature = "datagram")]
    _dispatcher_handle: tokio::task::JoinHandle<()>,
}
41
/// State shared between `Client`, its UDP sessions and the background dispatcher task.
pub(crate) struct ClientInner<T, C> {
    /// Transport used to (re)connect to the server.
    pub(crate) transport: T,
    /// Connection currently in use; atomically replaced on reconnect.
    pub(crate) connection: ArcSwap<C>,
    /// Serializes reconnect attempts so only one handshake runs at a time.
    reconnect_lock: Mutex<()>,
    /// Secret sent in the client hello; kept so reconnects can repeat the handshake.
    secret: Secret,
    /// Options sent in the client hello; kept so reconnects can repeat the handshake.
    options: Bytes,
    /// Source of UDP session ids (initialized to 1 in `Client::new`).
    #[cfg(feature = "datagram")]
    session_id_counter: AtomicU64,
    /// Routes reassembled inbound datagrams to their sessions.
    #[cfg(feature = "datagram")]
    pub(crate) udp_dispatcher: UdpDispatcher,
    /// Cancelled when the owning `Client` is dropped; ends the dispatcher task's loop.
    #[cfg(feature = "datagram")]
    pub(crate) shutdown_token: CancellationToken,
}
61
62impl<T, C> Client<T, C>
63where
64 T: Initiator<Connection = C>,
65 C: Connection,
66{
67 pub async fn new(transport: T, secret: Secret, options: Option<Bytes>) -> io::Result<Self> {
72 let options = options.unwrap_or_default();
73 let connection = handshake(&transport, secret, options.clone()).await?;
74
75 let inner = Arc::new(ClientInner {
76 transport,
77 connection: ArcSwap::new(Arc::new(connection)),
78 reconnect_lock: Mutex::new(()),
79 secret,
80 options,
81 #[cfg(feature = "datagram")]
82 session_id_counter: AtomicU64::new(1),
83 #[cfg(feature = "datagram")]
84 udp_dispatcher: UdpDispatcher::new(),
85 #[cfg(feature = "datagram")]
86 shutdown_token: CancellationToken::new(),
87 });
88
89 #[cfg(feature = "datagram")]
91 let dispatcher_handle = tokio::spawn(UdpDispatcher::run(Arc::clone(&inner)));
92
93 Ok(Self {
94 inner,
95 #[cfg(feature = "datagram")]
96 _dispatcher_handle: dispatcher_handle,
97 })
98 }
99
100 #[cfg(feature = "datagram")]
105 pub fn open_associate(&self) -> UdpSession<T, C> {
106 let session_id = self.inner.new_session_id();
107 info!(
108 "[Client] New UDP session created with session_id={}",
109 session_id
110 );
111 let receiver = self.inner.udp_dispatcher.register_session(session_id);
112
113 UdpSession::new(session_id, Arc::clone(&self.inner), receiver)
114 }
115
116 pub async fn open_bidirectional(&self, dest_addr: Address) -> io::Result<C::Stream> {
121 let mut stream = self
122 .inner
123 .with_retry(|conn| async move { conn.open_bidirectional().await })
124 .await?;
125
126 let connect_message = UpstreamMessage::Connect(ClientConnect { address: dest_addr });
127 let encoded_bytes = protocol::encode(&connect_message)?;
128
129 stream.write_u32(encoded_bytes.len() as u32).await?;
131 stream.write_all(&encoded_bytes).await?;
132
133 Ok(stream)
134 }
135
136 pub async fn rebind(&self) -> io::Result<()> {
138 self.inner.transport.rebind().await
139 }
140}
141
142impl<T, C> Drop for Client<T, C> {
143 fn drop(&mut self) {
144 #[cfg(feature = "datagram")]
146 self.inner.shutdown_token.cancel();
147 }
148}
149
150impl<T, C> ClientInner<T, C>
153where
154 T: Initiator<Connection = C>,
155 C: Connection,
156{
157 #[cfg(feature = "datagram")]
159 pub(crate) fn new_session_id(&self) -> u64 {
160 self.session_id_counter.fetch_add(1, Ordering::Relaxed)
161 }
162
163 pub(crate) async fn with_retry<F, Fut, R>(&self, operation: F) -> io::Result<R>
169 where
170 F: Fn(Guard<Arc<C>>) -> Fut,
171 Fut: Future<Output = io::Result<R>>,
172 {
173 let connection = self.connection.load();
174 let old_conn_id = Arc::as_ptr(&connection) as usize;
176
177 match operation(connection).await {
178 Ok(result) => Ok(result),
179 Err(e) if is_connection_error(&e) => {
180 warn!(
181 "Connection error detected: {}. Attempting to reconnect...",
182 e
183 );
184 self.reconnect(old_conn_id).await?;
185 let new_connection = self.connection.load();
186 operation(new_connection).await
187 }
188 Err(e) => Err(e),
189 }
190 }
191
192 async fn reconnect(&self, old_conn_id: usize) -> io::Result<()> {
196 let _lock = self.reconnect_lock.lock().await;
197
198 let current_conn_id = Arc::as_ptr(&self.connection.load()) as usize;
199 if current_conn_id == old_conn_id {
201 self.transport.rebind().await?;
203
204 let new_connection =
205 handshake(&self.transport, self.secret, self.options.clone()).await?;
206 self.connection.store(Arc::new(new_connection));
207 info!("Reconnection successful");
208 }
209
210 Ok(())
211 }
212}
213
214async fn handshake<T, C>(transport: &T, secret: Secret, options: Bytes) -> io::Result<C>
216where
217 T: Initiator<Connection = C>,
218 C: Connection,
219{
220 const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
221
222 let do_handshake = async {
223 let connection = transport.connect().await?;
224 let mut stream = connection.open_bidirectional().await?;
225
226 let hello_message = UpstreamMessage::Hello(ClientHello {
227 version: PROTOCOLS_VERSION,
228 secret,
229 options,
230 });
231
232 let encoded_bytes = protocol::encode(&hello_message)?;
233 let mut framed = Framed::new(&mut stream, length_codec());
234
235 framed.send(encoded_bytes).await?;
236
237 match framed.next().await {
238 Some(Ok(payload)) => {
239 let response: ServerHandshakeResponse = protocol::decode(&payload)?;
240 match response {
241 ServerHandshakeResponse::Ok => {
242 info!("Handshake with server successful");
243 stream.shutdown().await?;
244 Ok(connection)
245 }
246 ServerHandshakeResponse::Err(e) => {
247 error!("Handshake failed: {:?}", e);
248 let err_kind = match e {
249 HandshakeError::InvalidSecret => io::ErrorKind::PermissionDenied,
250 _ => io::ErrorKind::InvalidData,
251 };
252 Err(io::Error::new(
253 err_kind,
254 format!("Server rejected handshake: {:?}", e),
255 ))
256 }
257 }
258 }
259 Some(Err(e)) => Err(e),
260 None => Err(io::Error::new(
261 io::ErrorKind::UnexpectedEof,
262 "Connection closed by server during handshake",
263 )),
264 }
265 };
266
267 match tokio::time::timeout(HANDSHAKE_TIMEOUT, do_handshake).await {
268 Ok(result) => result,
269 Err(_) => Err(io::Error::new(
270 io::ErrorKind::TimedOut,
271 "Client hello timed out",
272 )),
273 }
274}
275
/// Returns `true` when `e` indicates that the underlying connection is unusable.
///
/// Operations failing with such an error are worth retrying on a fresh connection (see
/// `ClientInner::with_retry`). `ConnectionAborted` covers a connection terminated by the
/// peer and is as fatal as `ConnectionReset`. `HostUnreachable` and `NetworkDown` sit
/// alongside `NetworkUnreachable` as network-path failures.
fn is_connection_error(e: &io::Error) -> bool {
    matches!(
        e.kind(),
        io::ErrorKind::ConnectionReset
            | io::ErrorKind::ConnectionAborted
            | io::ErrorKind::BrokenPipe
            | io::ErrorKind::NotConnected
            | io::ErrorKind::TimedOut
            | io::ErrorKind::UnexpectedEof
            | io::ErrorKind::NetworkUnreachable
            | io::ErrorKind::HostUnreachable
            | io::ErrorKind::NetworkDown
    )
}
288
#[cfg(feature = "datagram")]
mod datagram {
    //! UDP support for the client: sending with fragmentation, receiving with reassembly,
    //! routing inbound datagrams to sessions, and the public `UdpSession` handle.

    use std::io;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    use bytes::Bytes;
    use ombrac::protocol::{Address, UdpPacket};
    use ombrac::reassembly::UdpReassembler;
    use ombrac_macros::{debug, warn};
    use ombrac_transport::{Connection, Initiator};

    use super::ClientInner;

    /// Sends `data` to `dest_addr` on behalf of UDP session `session_id`.
    ///
    /// Payloads within the per-datagram budget go out as one `UdpPacket::Unfragmented`.
    /// Larger ones are split with `UdpPacket::split_packet` under a fresh fragment id taken
    /// from `fragment_id_counter`. Each packet is sent through `ClientInner::with_retry`,
    /// so a connection-level failure may trigger one reconnect per packet. Empty payloads
    /// are ignored.
    ///
    /// # Errors
    ///
    /// Packet encoding errors, or the transport error from the (possibly retried) send.
    pub(crate) async fn send_datagram<T, C>(
        inner: &ClientInner<T, C>,
        session_id: u64,
        dest_addr: Address,
        data: Bytes,
        fragment_id_counter: &AtomicU32,
    ) -> io::Result<()>
    where
        T: Initiator<Connection = C>,
        C: Connection,
    {
        // Used when the connection does not report a datagram size limit.
        const DEFAULT_MAX_DATAGRAM_SIZE: usize = 1350;

        if data.is_empty() {
            return Ok(());
        }

        // Query the limit through a temporary guard. `ArcSwap` guards are meant to be
        // short-lived; keeping one for the rest of this function would hold it across every
        // `.await` below and pin the pre-reconnect connection in memory.
        let max_datagram_size = inner
            .connection
            .load()
            .max_datagram_size()
            .unwrap_or(DEFAULT_MAX_DATAGRAM_SIZE);
        // The fragmented-packet header overhead is used as the budget for both paths.
        // Clamp to at least one byte so the fragment size is never zero.
        let max_payload_size = max_datagram_size
            .saturating_sub(UdpPacket::fragmented_overhead())
            .max(1);

        if data.len() <= max_payload_size {
            let packet = UdpPacket::Unfragmented {
                session_id,
                address: dest_addr,
                data,
            };
            let encoded = packet.encode()?;
            inner
                .with_retry(|conn| {
                    let data_for_attempt = encoded.clone();
                    async move { conn.send_datagram(data_for_attempt).await }
                })
                .await?;
        } else {
            debug!(
                "[Session][{}] Sending packet for {} is too large ({} > max {}), fragmenting...",
                session_id,
                dest_addr,
                data.len(),
                max_payload_size
            );

            let fragment_id = fragment_id_counter.fetch_add(1, Ordering::Relaxed);
            let fragments =
                UdpPacket::split_packet(session_id, dest_addr, data, max_payload_size, fragment_id);

            for fragment in fragments {
                let packet_bytes = fragment.encode()?;
                inner
                    .with_retry(|conn| {
                        let data_for_attempt = packet_bytes.clone();
                        async move { conn.send_datagram(data_for_attempt).await }
                    })
                    .await?;
            }
        }
        Ok(())
    }

    /// Receives datagrams until one complete payload is available and returns it as
    /// `(session_id, address, payload)`.
    ///
    /// Undecodable packets and reassembly errors are logged and skipped. Packets for which
    /// the reassembler yields nothing yet (presumably partial fragment sets) are consumed
    /// silently.
    ///
    /// # Errors
    ///
    /// Only errors from reading the connection (after `with_retry`) are returned.
    pub(crate) async fn read_datagram<T, C>(
        inner: &ClientInner<T, C>,
        reassembler: &mut UdpReassembler,
    ) -> io::Result<(u64, Address, Bytes)>
    where
        T: Initiator<Connection = C>,
        C: Connection,
    {
        loop {
            let packet_bytes = inner
                .with_retry(|conn| async move { conn.read_datagram().await })
                .await?;

            let packet = match UdpPacket::decode(&packet_bytes) {
                Ok(packet) => packet,
                Err(_e) => {
                    warn!("Failed to decode UDP packet: {}. Discarding.", _e);
                    continue;
                }
            };

            match reassembler.process(packet).await {
                Ok(Some((session_id, address, data))) => {
                    return Ok((session_id, address, data));
                }
                Ok(None) => {}
                Err(_e) => {
                    warn!("Reassembly error: {}. Discarding fragment.", _e);
                }
            }
        }
    }

    pub(crate) mod dispatcher {
        //! Routes reassembled inbound datagrams to the session that owns them.

        use super::*;
        use dashmap::DashMap;
        use tokio::sync::mpsc;

        /// Sending half of a session's inbound queue: `(payload, address)` pairs.
        type UdpSessionSender = mpsc::Sender<(Bytes, Address)>;

        /// Capacity of each session's inbound queue.
        const SESSION_QUEUE_CAPACITY: usize = 128;

        /// Pause before reading again after `read_datagram` returned an error.
        const READ_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_secs(1);

        /// Registry of live UDP sessions plus the shared fragment-id source for outbound
        /// packets.
        pub(crate) struct UdpDispatcher {
            /// Inbound queue of each registered session, keyed by session id.
            dispatch_map: DashMap<u64, UdpSessionSender>,
            /// Source of fragment ids for `send_datagram`; atomic `fetch_add` wraps.
            fragment_id_counter: AtomicU32,
        }

        impl UdpDispatcher {
            pub(crate) fn new() -> Self {
                Self {
                    dispatch_map: DashMap::new(),
                    fragment_id_counter: AtomicU32::new(0),
                }
            }

            /// Background loop that reads datagrams from the connection and routes them
            /// until `shutdown_token` is cancelled.
            ///
            /// Each whole iteration is raced against shutdown. That covers reading,
            /// delivering to a session queue, and the back-off after a read error, so
            /// cancellation is observed even when a session queue is full.
            pub(crate) async fn run<T, C>(inner: Arc<ClientInner<T, C>>)
            where
                T: Initiator<Connection = C>,
                C: Connection,
            {
                let mut reassembler = UdpReassembler::default();

                loop {
                    tokio::select! {
                        _ = inner.shutdown_token.cancelled() => break,
                        _ = Self::pump_once(&inner, &mut reassembler) => {}
                    }
                }
            }

            /// Reads one complete datagram and delivers it, or backs off after a read error.
            async fn pump_once<T, C>(inner: &ClientInner<T, C>, reassembler: &mut UdpReassembler)
            where
                T: Initiator<Connection = C>,
                C: Connection,
            {
                match read_datagram(inner, reassembler).await {
                    Ok((session_id, address, data)) => {
                        inner.udp_dispatcher.dispatch(session_id, data, address).await;
                    }
                    Err(_e) => {
                        warn!("Error reading datagram: {}. Retrying after delay...", _e);
                        tokio::time::sleep(READ_ERROR_BACKOFF).await;
                    }
                }
            }

            /// Delivers one datagram to the queue of session `session_id`.
            ///
            /// A full queue makes this wait, which stalls the dispatcher loop and hence
            /// delivery for every session until that queue drains.
            async fn dispatch(&self, session_id: u64, data: Bytes, address: Address) {
                // Clone the sender and release the map entry *before* awaiting. A DashMap
                // `Ref` holds its shard's lock: keeping it across `send().await` blocks every
                // insert/remove on that shard, including `UdpSession::drop`. Calling `remove`
                // while still holding it deadlocks outright.
                let sender = self
                    .dispatch_map
                    .get(&session_id)
                    .map(|entry| entry.value().clone());

                let Some(tx) = sender else {
                    warn!(
                        "[Session][{}] Received datagram for UNKNOWN or CLOSED",
                        session_id
                    );
                    return;
                };

                if tx.send((data, address)).await.is_err() {
                    // The session's receiver is gone; drop its entry if still registered.
                    self.dispatch_map.remove(&session_id);
                }
            }

            /// Registers session `session_id` and returns the receiving end of its queue.
            pub(crate) fn register_session(
                &self,
                session_id: u64,
            ) -> mpsc::Receiver<(Bytes, Address)> {
                let (tx, rx) = mpsc::channel(SESSION_QUEUE_CAPACITY);
                self.dispatch_map.insert(session_id, tx);
                rx
            }

            /// Removes session `session_id`; later datagrams for it are logged and dropped.
            pub(crate) fn unregister_session(&self, session_id: u64) {
                self.dispatch_map.remove(&session_id);
            }

            /// Counter used to assign fragment ids to outbound fragmented packets.
            pub(crate) fn fragment_id_counter(&self) -> &AtomicU32 {
                &self.fragment_id_counter
            }
        }
    }

    pub mod session {
        //! Public handle for a single UDP association.

        use super::*;
        use crate::client::datagram::ClientInner;
        use tokio::sync::mpsc;

        /// One UDP session multiplexed over the client's connection.
        ///
        /// Outbound datagrams are tagged with this session's id. Inbound datagrams carrying
        /// the same id are queued for it by the client's dispatcher. Dropping the handle
        /// unregisters the session.
        pub struct UdpSession<T, C>
        where
            T: Initiator<Connection = C>,
            C: Connection,
        {
            session_id: u64,
            client_inner: Arc<ClientInner<T, C>>,
            receiver: mpsc::Receiver<(Bytes, Address)>,
        }

        impl<T, C> UdpSession<T, C>
        where
            T: Initiator<Connection = C>,
            C: Connection,
        {
            pub(crate) fn new(
                session_id: u64,
                client_inner: Arc<ClientInner<T, C>>,
                receiver: mpsc::Receiver<(Bytes, Address)>,
            ) -> Self {
                Self {
                    session_id,
                    client_inner,
                    receiver,
                }
            }

            /// Sends `data` to `dest_addr`, fragmenting it if it exceeds the datagram budget.
            ///
            /// # Errors
            ///
            /// Propagates encoding and transport errors from `send_datagram`.
            pub async fn send_to(&self, data: Bytes, dest_addr: Address) -> io::Result<()> {
                send_datagram(
                    &self.client_inner,
                    self.session_id,
                    dest_addr,
                    data,
                    self.client_inner.udp_dispatcher.fragment_id_counter(),
                )
                .await
            }

            /// Waits for the next inbound datagram as `(payload, address)`.
            ///
            /// Returns `None` once the dispatcher side of the queue has been dropped.
            pub async fn recv_from(&mut self) -> Option<(Bytes, Address)> {
                self.receiver.recv().await
            }
        }

        impl<T, C> Drop for UdpSession<T, C>
        where
            T: Initiator<Connection = C>,
            C: Connection,
        {
            fn drop(&mut self) {
                self.client_inner
                    .udp_dispatcher
                    .unregister_session(self.session_id);
            }
        }
    }
}