1use std::future::Future;
2use std::io;
3use std::sync::Arc;
4#[cfg(feature = "datagram")]
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use arc_swap::{ArcSwap, Guard};
9use bytes::Bytes;
10use futures::{SinkExt, StreamExt};
11use tokio::io::AsyncWriteExt;
12use tokio::sync::Mutex;
13use tokio::time::Instant;
14use tokio_util::codec::Framed;
15#[cfg(feature = "datagram")]
16use tokio_util::sync::CancellationToken;
17
18use ombrac::codec::{UpstreamMessage, length_codec};
19use ombrac::protocol::{
20 self, Address, ClientConnect, ClientHello, HandshakeError, PROTOCOLS_VERSION, Secret,
21 ServerHandshakeResponse,
22};
23use ombrac_macros::{error, info, warn};
24use ombrac_transport::{Connection, Initiator};
25
26#[cfg(feature = "datagram")]
27use datagram::dispatcher::UdpDispatcher;
28#[cfg(feature = "datagram")]
29pub use datagram::session::UdpSession;
30
31struct ReconnectState {
32 last_attempt: Option<Instant>,
33 backoff: Duration,
34}
35
36impl Default for ReconnectState {
37 fn default() -> Self {
38 Self {
39 last_attempt: None,
40 backoff: Duration::from_secs(1),
41 }
42 }
43}
44
45pub struct Client<T, C> {
51 inner: Arc<ClientInner<T, C>>,
53 #[cfg(feature = "datagram")]
55 _dispatcher_handle: tokio::task::JoinHandle<()>,
56}
57
58pub(crate) struct ClientInner<T, C> {
63 pub(crate) transport: T,
64 pub(crate) connection: ArcSwap<C>,
65 reconnect_lock: Mutex<ReconnectState>,
67 secret: Secret,
68 options: Bytes,
69 #[cfg(feature = "datagram")]
70 session_id_counter: AtomicU64,
71 #[cfg(feature = "datagram")]
72 pub(crate) udp_dispatcher: UdpDispatcher,
73 #[cfg(feature = "datagram")]
75 pub(crate) shutdown_token: CancellationToken,
76}
77
78impl<T, C> Client<T, C>
79where
80 T: Initiator<Connection = C>,
81 C: Connection,
82{
83 pub async fn new(transport: T, secret: Secret, options: Option<Bytes>) -> io::Result<Self> {
88 let options = options.unwrap_or_default();
89 let connection = handshake(&transport, secret, options.clone()).await?;
90
91 let inner = Arc::new(ClientInner {
92 transport,
93 connection: ArcSwap::new(Arc::new(connection)),
94 reconnect_lock: Mutex::new(ReconnectState::default()),
95 secret,
96 options,
97 #[cfg(feature = "datagram")]
98 session_id_counter: AtomicU64::new(1),
99 #[cfg(feature = "datagram")]
100 udp_dispatcher: UdpDispatcher::new(),
101 #[cfg(feature = "datagram")]
102 shutdown_token: CancellationToken::new(),
103 });
104
105 #[cfg(feature = "datagram")]
107 let dispatcher_handle = tokio::spawn(UdpDispatcher::run(Arc::clone(&inner)));
108
109 Ok(Self {
110 inner,
111 #[cfg(feature = "datagram")]
112 _dispatcher_handle: dispatcher_handle,
113 })
114 }
115
116 #[cfg(feature = "datagram")]
121 pub fn open_associate(&self) -> UdpSession<T, C> {
122 let session_id = self.inner.new_session_id();
123 info!(
124 "[Client] New UDP session created with session_id={}",
125 session_id
126 );
127 let receiver = self.inner.udp_dispatcher.register_session(session_id);
128
129 UdpSession::new(session_id, Arc::clone(&self.inner), receiver)
130 }
131
132 pub async fn open_bidirectional(&self, dest_addr: Address) -> io::Result<C::Stream> {
137 let mut stream = self
138 .inner
139 .with_retry(|conn| async move { conn.open_bidirectional().await })
140 .await?;
141
142 let connect_message = UpstreamMessage::Connect(ClientConnect { address: dest_addr });
143 let encoded_bytes = protocol::encode(&connect_message)?;
144
145 stream.write_u32(encoded_bytes.len() as u32).await?;
147 stream.write_all(&encoded_bytes).await?;
148
149 Ok(stream)
150 }
151
152 pub async fn rebind(&self) -> io::Result<()> {
154 self.inner.transport.rebind().await
155 }
156}
157
158impl<T, C> Drop for Client<T, C> {
159 fn drop(&mut self) {
160 #[cfg(feature = "datagram")]
162 self.inner.shutdown_token.cancel();
163 }
164}
165
166impl<T, C> ClientInner<T, C>
169where
170 T: Initiator<Connection = C>,
171 C: Connection,
172{
173 #[cfg(feature = "datagram")]
175 pub(crate) fn new_session_id(&self) -> u64 {
176 self.session_id_counter.fetch_add(1, Ordering::Relaxed)
177 }
178
179 pub(crate) async fn with_retry<F, Fut, R>(&self, operation: F) -> io::Result<R>
185 where
186 F: Fn(Guard<Arc<C>>) -> Fut,
187 Fut: Future<Output = io::Result<R>>,
188 {
189 let connection = self.connection.load();
190 let old_conn_id = Arc::as_ptr(&connection) as usize;
192
193 match operation(connection).await {
194 Ok(result) => Ok(result),
195 Err(e) if is_connection_error(&e) => {
196 warn!(
197 "Connection error detected: {}. Attempting to reconnect...",
198 e
199 );
200 self.reconnect(old_conn_id).await?;
201 let new_connection = self.connection.load();
202 operation(new_connection).await
203 }
204 Err(e) => Err(e),
205 }
206 }
207
208 async fn reconnect(&self, old_conn_id: usize) -> io::Result<()> {
212 let mut state = self.reconnect_lock.lock().await;
213
214 let current_conn = self.connection.load();
215 let current_conn_id = Arc::as_ptr(¤t_conn) as usize;
216 if current_conn_id != old_conn_id {
217 return Ok(());
218 }
219
220 if let Some(last) = state.last_attempt {
221 let elapsed = last.elapsed();
222 if elapsed < state.backoff {
223 let wait_time = state.backoff - elapsed;
224 warn!("Too many reconnect attempts. Global throttling for {:?}...", wait_time);
225
226 drop(state);
227 tokio::time::sleep(wait_time).await;
228 return Err(io::Error::new(io::ErrorKind::Other, "Reconnect throttled"));
229 }
230 }
231
232 info!("Attemping to reconnect...");
233
234 state.last_attempt = Some(Instant::now());
235
236 if let Err(e) = self.transport.rebind().await {
237 state.backoff = (state.backoff * 2).min(Duration::from_secs(60));
238 return Err(e);
239 }
240
241 match handshake(&self.transport, self.secret, self.options.clone()).await {
242 Ok(new_connection) => {
243 state.backoff = Duration::from_secs(1);
244 state.last_attempt = None;
245
246 self.connection.store(Arc::new(new_connection));
247 info!("Reconnection successful");
248 Ok(())
249 }
250 Err(e) => {
251 state.backoff = (state.backoff * 2).min(Duration::from_secs(60));
252 error!("Reconnect failed: {}. Next retry backoff: {:?}", e, state.backoff);
253 Err(e)
254 }
255 }
256 }
257}
258
259async fn handshake<T, C>(transport: &T, secret: Secret, options: Bytes) -> io::Result<C>
261where
262 T: Initiator<Connection = C>,
263 C: Connection,
264{
265 const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
266
267 let do_handshake = async {
268 let connection = transport.connect().await?;
269 let mut stream = connection.open_bidirectional().await?;
270
271 let hello_message = UpstreamMessage::Hello(ClientHello {
272 version: PROTOCOLS_VERSION,
273 secret,
274 options,
275 });
276
277 let encoded_bytes = protocol::encode(&hello_message)?;
278 let mut framed = Framed::new(&mut stream, length_codec());
279
280 framed.send(encoded_bytes).await?;
281
282 match framed.next().await {
283 Some(Ok(payload)) => {
284 let response: ServerHandshakeResponse = protocol::decode(&payload)?;
285 match response {
286 ServerHandshakeResponse::Ok => {
287 info!("Handshake with server successful");
288 stream.shutdown().await?;
289 Ok(connection)
290 }
291 ServerHandshakeResponse::Err(e) => {
292 error!("Handshake failed: {:?}", e);
293 let err_kind = match e {
294 HandshakeError::InvalidSecret => io::ErrorKind::PermissionDenied,
295 _ => io::ErrorKind::InvalidData,
296 };
297 Err(io::Error::new(
298 err_kind,
299 format!("Server rejected handshake: {:?}", e),
300 ))
301 }
302 }
303 }
304 Some(Err(e)) => Err(e),
305 None => Err(io::Error::new(
306 io::ErrorKind::UnexpectedEof,
307 "Connection closed by server during handshake",
308 )),
309 }
310 };
311
312 match tokio::time::timeout(HANDSHAKE_TIMEOUT, do_handshake).await {
313 Ok(result) => result,
314 Err(_) => Err(io::Error::new(
315 io::ErrorKind::TimedOut,
316 "Client hello timed out",
317 )),
318 }
319}
320
/// Returns `true` when `e` indicates that the underlying connection is broken or
/// unreachable, so reconnecting and retrying the operation may succeed.
///
/// `ClientInner::with_retry` uses this to decide whether to reconnect.
fn is_connection_error(e: &io::Error) -> bool {
    matches!(
        e.kind(),
        io::ErrorKind::ConnectionReset
            // The remote side terminated the connection. Without this kind, a connection
            // closed by the peer would never be replaced.
            | io::ErrorKind::ConnectionAborted
            | io::ErrorKind::BrokenPipe
            | io::ErrorKind::NotConnected
            | io::ErrorKind::TimedOut
            | io::ErrorKind::UnexpectedEof
            | io::ErrorKind::NetworkUnreachable
    )
}
333
#[cfg(feature = "datagram")]
mod datagram {
    use std::io;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    use bytes::Bytes;
    use ombrac::protocol::{Address, UdpPacket};
    use ombrac::reassembly::UdpReassembler;
    use ombrac_macros::{debug, warn};
    use ombrac_transport::{Connection, Initiator};

    use super::ClientInner;

    /// Datagram size assumed when the connection does not report its maximum.
    const DEFAULT_MAX_DATAGRAM_SIZE: usize = 1350;

    /// Sends `data` to `dest_addr` on behalf of UDP session `session_id`.
    ///
    /// If `data` fits within one datagram's payload budget, it is sent as a single
    /// unfragmented packet. Otherwise it is split with `UdpPacket::split_packet` under a
    /// new fragment id taken from `fragment_id_counter`. Every datagram goes through
    /// `with_retry`. Empty payloads are ignored.
    ///
    /// # Errors
    /// Returns encoding errors and send errors (after `with_retry`'s single reconnect).
    /// Fragments sent before a failing one are not recalled.
    pub(crate) async fn send_datagram<T, C>(
        inner: &ClientInner<T, C>,
        session_id: u64,
        dest_addr: Address,
        data: Bytes,
        fragment_id_counter: &AtomicU32,
    ) -> io::Result<()>
    where
        T: Initiator<Connection = C>,
        C: Connection,
    {
        if data.is_empty() {
            return Ok(());
        }

        // Only the size limit is needed here. The guard is a temporary, so it is not
        // held across the sends below.
        let max_datagram_size = inner
            .connection
            .load()
            .max_datagram_size()
            .unwrap_or(DEFAULT_MAX_DATAGRAM_SIZE);
        let overhead = UdpPacket::fragmented_overhead();
        // Keep the payload budget positive even if the overhead exceeds the limit.
        let max_payload_size = max_datagram_size.saturating_sub(overhead).max(1);

        if data.len() <= max_payload_size {
            let packet = UdpPacket::Unfragmented {
                session_id,
                address: dest_addr,
                data,
            };
            let encoded = packet.encode()?;
            inner
                .with_retry(|conn| {
                    let data_for_attempt = encoded.clone();
                    async move { conn.send_datagram(data_for_attempt).await }
                })
                .await?;
        } else {
            debug!(
                "[Session][{}] Sending packet for {} is too large ({} > max {}), fragmenting...",
                session_id,
                dest_addr,
                data.len(),
                max_payload_size
            );

            let fragment_id = fragment_id_counter.fetch_add(1, Ordering::Relaxed);
            let fragments =
                UdpPacket::split_packet(session_id, dest_addr, data, max_payload_size, fragment_id);

            for fragment in fragments {
                let packet_bytes = fragment.encode()?;
                inner
                    .with_retry(|conn| {
                        let data_for_attempt = packet_bytes.clone();
                        async move { conn.send_datagram(data_for_attempt).await }
                    })
                    .await?;
            }
        }
        Ok(())
    }

    /// Reads datagrams until one completes a UDP packet, and returns it as
    /// `(session_id, address, data)`.
    ///
    /// Undecodable packets and reassembly errors are logged and skipped. Fragments of a
    /// packet that is not yet complete are kept in `reassembler`.
    ///
    /// # Errors
    /// Propagates read errors from `with_retry` (after its single reconnect attempt).
    pub(crate) async fn read_datagram<T, C>(
        inner: &ClientInner<T, C>,
        reassembler: &mut UdpReassembler,
    ) -> io::Result<(u64, Address, Bytes)>
    where
        T: Initiator<Connection = C>,
        C: Connection,
    {
        loop {
            let packet_bytes = inner
                .with_retry(|conn| async move { conn.read_datagram().await })
                .await?;

            let packet = match UdpPacket::decode(&packet_bytes) {
                Ok(packet) => packet,
                Err(_e) => {
                    warn!("Failed to decode UDP packet: {}. Discarding.", _e);
                    continue;
                }
            };

            match reassembler.process(packet).await {
                Ok(Some((session_id, address, data))) => {
                    return Ok((session_id, address, data));
                }
                // More fragments are needed before this packet is complete.
                Ok(None) => {}
                Err(_e) => {
                    warn!("Reassembly error: {}. Discarding fragment.", _e);
                }
            }
        }
    }

    pub(crate) mod dispatcher {
        use std::time::Duration;

        use super::*;
        use dashmap::DashMap;
        use tokio::sync::mpsc;
        use tokio::sync::mpsc::error::TrySendError;

        type UdpSessionSender = mpsc::Sender<(Bytes, Address)>;

        /// Capacity of each session's inbound queue. Datagrams arriving while it is full
        /// are dropped.
        const SESSION_QUEUE_CAPACITY: usize = 128;

        /// Routes datagrams read from the connection to the session they belong to.
        pub(crate) struct UdpDispatcher {
            /// Inbound queue of every registered session, keyed by session id.
            dispatch_map: DashMap<u64, UdpSessionSender>,
            /// Fragment-id source for oversized outgoing datagrams, shared by all sessions.
            fragment_id_counter: AtomicU32,
        }

        impl UdpDispatcher {
            pub(crate) fn new() -> Self {
                Self {
                    dispatch_map: DashMap::new(),
                    fragment_id_counter: AtomicU32::new(0),
                }
            }

            /// Background receive loop, spawned by `Client::new`.
            ///
            /// Reads and reassembles datagrams and hands each one to its session. Read
            /// errors are retried with exponential backoff: 1 s, doubling up to 60 s.
            /// The loop stops as soon as `shutdown_token` is cancelled, including while
            /// backing off. It then drops every registered sender, so sessions'
            /// `recv_from` calls return `None` instead of waiting forever.
            pub(crate) async fn run<T, C>(inner: Arc<ClientInner<T, C>>)
            where
                T: Initiator<Connection = C>,
                C: Connection,
            {
                const INITIAL_DELAY: Duration = Duration::from_secs(1);
                const MAX_DELAY: Duration = Duration::from_secs(60);

                let mut reassembler = UdpReassembler::default();
                let mut current_delay = INITIAL_DELAY;

                loop {
                    tokio::select! {
                        _ = inner.shutdown_token.cancelled() => {
                            break;
                        }
                        result = read_datagram(&inner, &mut reassembler) => {
                            match result {
                                Ok((session_id, address, data)) => {
                                    current_delay = INITIAL_DELAY;
                                    inner.udp_dispatcher.dispatch(session_id, data, address);
                                }
                                Err(_e) => {
                                    warn!("Error reading datagram: {}. Retrying in {:?}...", _e, current_delay);
                                    // Back off, but stay responsive to shutdown while waiting.
                                    tokio::select! {
                                        _ = inner.shutdown_token.cancelled() => break,
                                        _ = tokio::time::sleep(current_delay) => {}
                                    }
                                    current_delay = (current_delay * 2).min(MAX_DELAY);
                                }
                            }
                        }
                    }
                }

                // Sessions may outlive the client; dropping the senders closes their
                // queues so receivers observe the shutdown.
                inner.udp_dispatcher.dispatch_map.clear();
            }

            /// Delivers one datagram to its session's queue without blocking.
            ///
            /// The sender is cloned out of the map first, so no DashMap shard lock is held
            /// while using the channel or while removing the entry. Holding a map
            /// reference across either can deadlock. A full queue drops the datagram
            /// (normal UDP loss); this stops one slow session from stalling delivery to
            /// all the others. A closed queue means the session is gone, so its entry is
            /// removed.
            fn dispatch(&self, session_id: u64, data: Bytes, address: Address) {
                let Some(tx) = self
                    .dispatch_map
                    .get(&session_id)
                    .map(|entry| entry.value().clone())
                else {
                    warn!(
                        "[Session][{}] Received datagram for UNKNOWN or CLOSED",
                        session_id
                    );
                    return;
                };

                match tx.try_send((data, address)) {
                    Ok(()) => {}
                    Err(TrySendError::Full(_)) => {
                        debug!(
                            "[Session][{}] Receive queue full, dropping datagram",
                            session_id
                        );
                    }
                    Err(TrySendError::Closed(_)) => {
                        self.dispatch_map.remove(&session_id);
                    }
                }
            }

            /// Creates the inbound queue for `session_id` and returns its receiving end.
            /// A previous registration under the same id is replaced.
            pub(crate) fn register_session(
                &self,
                session_id: u64,
            ) -> mpsc::Receiver<(Bytes, Address)> {
                let (tx, rx) = mpsc::channel(SESSION_QUEUE_CAPACITY);
                self.dispatch_map.insert(session_id, tx);
                rx
            }

            /// Removes the inbound queue for `session_id`, if any.
            pub(crate) fn unregister_session(&self, session_id: u64) {
                self.dispatch_map.remove(&session_id);
            }

            /// Shared counter used to allocate fragment ids for outgoing datagrams.
            pub(crate) fn fragment_id_counter(&self) -> &AtomicU32 {
                &self.fragment_id_counter
            }
        }
    }

    pub mod session {
        use super::*;
        use crate::client::datagram::ClientInner;
        use tokio::sync::mpsc;

        /// A UDP association multiplexed over the client's connection.
        ///
        /// Send with `send_to`; datagrams routed to this session by the dispatcher are
        /// read with `recv_from`. Dropping the session unregisters it from the dispatcher.
        pub struct UdpSession<T, C>
        where
            T: Initiator<Connection = C>,
            C: Connection,
        {
            session_id: u64,
            client_inner: Arc<ClientInner<T, C>>,
            receiver: mpsc::Receiver<(Bytes, Address)>,
        }

        impl<T, C> UdpSession<T, C>
        where
            T: Initiator<Connection = C>,
            C: Connection,
        {
            pub(crate) fn new(
                session_id: u64,
                client_inner: Arc<ClientInner<T, C>>,
                receiver: mpsc::Receiver<(Bytes, Address)>,
            ) -> Self {
                Self {
                    session_id,
                    client_inner,
                    receiver,
                }
            }

            /// Sends `data` to `dest_addr`, fragmenting it if it exceeds one datagram.
            ///
            /// # Errors
            /// Propagates encoding and send errors (see `send_datagram`).
            pub async fn send_to(&self, data: Bytes, dest_addr: Address) -> io::Result<()> {
                send_datagram(
                    &self.client_inner,
                    self.session_id,
                    dest_addr,
                    data,
                    self.client_inner.udp_dispatcher.fragment_id_counter(),
                )
                .await
            }

            /// Waits for the next datagram routed to this session and returns its payload
            /// together with the address carried in the packet (presumably the remote
            /// peer; confirm against `UdpPacket`).
            ///
            /// Returns `None` once the session's queue has been closed by the dispatcher.
            pub async fn recv_from(&mut self) -> Option<(Bytes, Address)> {
                self.receiver.recv().await
            }
        }

        impl<T, C> Drop for UdpSession<T, C>
        where
            T: Initiator<Connection = C>,
            C: Connection,
        {
            fn drop(&mut self) {
                self.client_inner
                    .udp_dispatcher
                    .unregister_session(self.session_id);
            }
        }
    }
}