1use std::future::Future;
2use std::io;
3use std::sync::Arc;
4#[cfg(feature = "datagram")]
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use arc_swap::{ArcSwap, Guard};
9use bytes::Bytes;
10use futures::{SinkExt, StreamExt};
11use tokio::io::AsyncWriteExt;
12use tokio::sync::Mutex;
13use tokio_util::codec::Framed;
14#[cfg(feature = "datagram")]
15use tokio_util::sync::CancellationToken;
16
17use ombrac::codec::{UpstreamMessage, length_codec};
18use ombrac::protocol::{
19 self, Address, ClientConnect, ClientHello, HandshakeError, PROTOCOLS_VERSION, Secret,
20 ServerHandshakeResponse,
21};
22use ombrac_macros::{error, info, warn};
23use ombrac_transport::{Connection, Initiator};
24
25#[cfg(feature = "datagram")]
26use datagram::dispatcher::UdpDispatcher;
27#[cfg(feature = "datagram")]
28pub use datagram::session::UdpSession;
29
30pub struct Client<T, C> {
36 inner: Arc<ClientInner<T, C>>,
38 #[cfg(feature = "datagram")]
40 _dispatcher_handle: tokio::task::JoinHandle<()>,
41}
42
43pub(crate) struct ClientInner<T, C> {
48 pub(crate) transport: T,
49 pub(crate) connection: ArcSwap<C>,
50 reconnect_lock: Mutex<()>,
52 secret: Secret,
53 options: Bytes,
54 #[cfg(feature = "datagram")]
55 session_id_counter: AtomicU64,
56 #[cfg(feature = "datagram")]
57 pub(crate) udp_dispatcher: UdpDispatcher,
58 #[cfg(feature = "datagram")]
60 pub(crate) shutdown_token: CancellationToken,
61}
62
63impl<T, C> Client<T, C>
64where
65 T: Initiator<Connection = C>,
66 C: Connection,
67{
68 pub async fn new(transport: T, secret: Secret, options: Option<Bytes>) -> io::Result<Self> {
73 let options = options.unwrap_or_default();
74 let connection = handshake(&transport, secret, options.clone()).await?;
75
76 let inner = Arc::new(ClientInner {
77 transport,
78 connection: ArcSwap::new(Arc::new(connection)),
79 reconnect_lock: Mutex::new(()),
80 secret,
81 options,
82 #[cfg(feature = "datagram")]
83 session_id_counter: AtomicU64::new(1),
84 #[cfg(feature = "datagram")]
85 udp_dispatcher: UdpDispatcher::new(),
86 #[cfg(feature = "datagram")]
87 shutdown_token: CancellationToken::new(),
88 });
89
90 #[cfg(feature = "datagram")]
92 let dispatcher_handle = tokio::spawn(UdpDispatcher::run(Arc::clone(&inner)));
93
94 Ok(Self {
95 inner,
96 #[cfg(feature = "datagram")]
97 _dispatcher_handle: dispatcher_handle,
98 })
99 }
100
101 #[cfg(feature = "datagram")]
106 pub fn open_associate(&self) -> UdpSession<T, C> {
107 let session_id = self.inner.new_session_id();
108 info!(
109 "[Client] New UDP session created with session_id={}",
110 session_id
111 );
112 let receiver = self.inner.udp_dispatcher.register_session(session_id);
113
114 UdpSession::new(session_id, Arc::clone(&self.inner), receiver)
115 }
116
117 pub async fn open_bidirectional(&self, dest_addr: Address) -> io::Result<C::Stream> {
122 let mut stream = self
123 .inner
124 .with_retry(|conn| async move { conn.open_bidirectional().await })
125 .await?;
126
127 let connect_message = UpstreamMessage::Connect(ClientConnect { address: dest_addr });
128 let encoded_bytes = protocol::encode(&connect_message)?;
129
130 stream.write_u32(encoded_bytes.len() as u32).await?;
132 stream.write_all(&encoded_bytes).await?;
133
134 Ok(stream)
135 }
136
137 pub async fn rebind(&self) -> io::Result<()> {
139 self.inner.transport.rebind().await
140 }
141}
142
143impl<T, C> Drop for Client<T, C> {
144 fn drop(&mut self) {
145 #[cfg(feature = "datagram")]
147 self.inner.shutdown_token.cancel();
148 }
149}
150
151impl<T, C> ClientInner<T, C>
154where
155 T: Initiator<Connection = C>,
156 C: Connection,
157{
158 #[cfg(feature = "datagram")]
160 pub(crate) fn new_session_id(&self) -> u64 {
161 self.session_id_counter.fetch_add(1, Ordering::Relaxed)
162 }
163
164 pub(crate) async fn with_retry<F, Fut, R>(&self, operation: F) -> io::Result<R>
170 where
171 F: Fn(Guard<Arc<C>>) -> Fut,
172 Fut: Future<Output = io::Result<R>>,
173 {
174 let connection = self.connection.load();
175 let old_conn_id = Arc::as_ptr(&connection) as usize;
177
178 match operation(connection).await {
179 Ok(result) => Ok(result),
180 Err(e) if is_connection_error(&e) => {
181 warn!(
182 "Connection error detected: {}. Attempting to reconnect...",
183 e
184 );
185 self.reconnect(old_conn_id).await?;
186 let new_connection = self.connection.load();
187 operation(new_connection).await
188 }
189 Err(e) => Err(e),
190 }
191 }
192
193 async fn reconnect(&self, old_conn_id: usize) -> io::Result<()> {
197 let _lock = self.reconnect_lock.lock().await;
198
199 let current_conn_id = Arc::as_ptr(&self.connection.load()) as usize;
200 if current_conn_id == old_conn_id {
202 self.transport.rebind().await?;
204
205 let new_connection =
206 handshake(&self.transport, self.secret, self.options.clone()).await?;
207 self.connection.store(Arc::new(new_connection));
208 info!("Reconnection successful");
209 }
210
211 Ok(())
212 }
213}
214
215async fn handshake<T, C>(transport: &T, secret: Secret, options: Bytes) -> io::Result<C>
217where
218 T: Initiator<Connection = C>,
219 C: Connection,
220{
221 const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
222
223 let do_handshake = async {
224 let connection = transport.connect().await?;
225 let mut stream = connection.open_bidirectional().await?;
226
227 let hello_message = UpstreamMessage::Hello(ClientHello {
228 version: PROTOCOLS_VERSION,
229 secret,
230 options,
231 });
232
233 let encoded_bytes = protocol::encode(&hello_message)?;
234 let mut framed = Framed::new(&mut stream, length_codec());
235
236 framed.send(encoded_bytes).await?;
237
238 match framed.next().await {
239 Some(Ok(payload)) => {
240 let response: ServerHandshakeResponse = protocol::decode(&payload)?;
241 match response {
242 ServerHandshakeResponse::Ok => {
243 info!("Handshake with server successful");
244 stream.shutdown().await?;
245 Ok(connection)
246 }
247 ServerHandshakeResponse::Err(e) => {
248 error!("Handshake failed: {:?}", e);
249 let err_kind = match e {
250 HandshakeError::InvalidSecret => io::ErrorKind::PermissionDenied,
251 _ => io::ErrorKind::InvalidData,
252 };
253 Err(io::Error::new(
254 err_kind,
255 format!("Server rejected handshake: {:?}", e),
256 ))
257 }
258 }
259 }
260 Some(Err(e)) => Err(e),
261 None => Err(io::Error::new(
262 io::ErrorKind::UnexpectedEof,
263 "Connection closed by server during handshake",
264 )),
265 }
266 };
267
268 match tokio::time::timeout(HANDSHAKE_TIMEOUT, do_handshake).await {
269 Ok(result) => result,
270 Err(_) => Err(io::Error::new(
271 io::ErrorKind::TimedOut,
272 "Client hello timed out",
273 )),
274 }
275}
276
/// Returns `true` when `e` indicates that the underlying connection is gone, so
/// re-establishing it (and retrying) may succeed.
///
/// `ConnectionAborted` is included because transports such as quinn report a
/// connection closed by the peer or application with that kind. Without it,
/// such failures would never trigger a reconnect.
fn is_connection_error(e: &io::Error) -> bool {
    matches!(
        e.kind(),
        io::ErrorKind::ConnectionReset
            | io::ErrorKind::ConnectionAborted
            | io::ErrorKind::BrokenPipe
            | io::ErrorKind::NotConnected
            | io::ErrorKind::TimedOut
            | io::ErrorKind::UnexpectedEof
            | io::ErrorKind::NetworkUnreachable
    )
}
289
290#[cfg(feature = "datagram")]
293mod datagram {
294 use std::io;
295 use std::sync::Arc;
296 use std::sync::atomic::{AtomicU32, Ordering};
297
298 use bytes::Bytes;
299 use ombrac::protocol::{Address, UdpPacket};
300 use ombrac::reassembly::UdpReassembler;
301 use ombrac_macros::{debug, warn};
302 use ombrac_transport::{Connection, Initiator};
303
304 use super::ClientInner;
305
306 pub(crate) async fn send_datagram<T, C>(
308 inner: &ClientInner<T, C>,
309 session_id: u64,
310 dest_addr: Address,
311 data: Bytes,
312 fragment_id_counter: &AtomicU32,
313 ) -> io::Result<()>
314 where
315 T: Initiator<Connection = C>,
316 C: Connection,
317 {
318 if data.is_empty() {
319 return Ok(());
320 }
321
322 let connection = inner.connection.load();
323 let max_datagram_size = connection.max_datagram_size().unwrap_or(1350);
325 let overhead = UdpPacket::fragmented_overhead();
327 let max_payload_size = max_datagram_size.saturating_sub(overhead).max(1);
328
329 if data.len() <= max_payload_size {
330 let packet = UdpPacket::Unfragmented {
331 session_id,
332 address: dest_addr.clone(),
333 data,
334 };
335 let encoded = packet.encode()?;
336 inner
337 .with_retry(|conn| {
338 let data_for_attempt = encoded.clone();
339 async move { conn.send_datagram(data_for_attempt).await }
340 })
341 .await?;
342 } else {
343 debug!(
345 "[Session][{}] Sending packet for {} is too large ({} > max {}), fragmenting...",
346 session_id,
347 dest_addr,
348 data.len(),
349 max_payload_size
350 );
351
352 let fragment_id = fragment_id_counter.fetch_add(1, Ordering::Relaxed);
353 let fragments =
354 UdpPacket::split_packet(session_id, dest_addr, data, max_payload_size, fragment_id);
355
356 for fragment in fragments {
357 let packet_bytes = fragment.encode()?;
358 inner
359 .with_retry(|conn| {
360 let data_for_attempt = packet_bytes.clone();
361 async move { conn.send_datagram(data_for_attempt).await }
362 })
363 .await?;
364 }
365 }
366 Ok(())
367 }
368
369 pub(crate) async fn read_datagram<T, C>(
371 inner: &ClientInner<T, C>,
372 reassembler: &mut UdpReassembler,
373 ) -> io::Result<(u64, Address, Bytes)>
374 where
375 T: Initiator<Connection = C>,
376 C: Connection,
377 {
378 loop {
379 let packet_bytes = inner
380 .with_retry(|conn| async move { conn.read_datagram().await })
381 .await?;
382
383 let packet = match UdpPacket::decode(&packet_bytes) {
384 Ok(packet) => packet,
385 Err(_e) => {
386 warn!("Failed to decode UDP packet: {}. Discarding.", _e);
387 continue; }
389 };
390
391 match reassembler.process(packet).await {
392 Ok(Some((session_id, address, data))) => {
393 return Ok((session_id, address, data));
394 }
395 Ok(None) => {
396 continue; }
398 Err(_e) => {
399 warn!("Reassembly error: {}. Discarding fragment.", _e);
400 continue; }
402 }
403 }
404 }
405
406 pub(crate) mod dispatcher {
408 use super::*;
409 use dashmap::DashMap;
410 use tokio::sync::mpsc;
411
412 type UdpSessionSender = mpsc::Sender<(Bytes, Address)>;
413
414 pub(crate) struct UdpDispatcher {
416 dispatch_map: DashMap<u64, UdpSessionSender>,
418 fragment_id_counter: AtomicU32,
419 }
420
421 impl UdpDispatcher {
422 pub(crate) fn new() -> Self {
423 Self {
424 dispatch_map: DashMap::new(),
425 fragment_id_counter: AtomicU32::new(0),
426 }
427 }
428
429 pub(crate) async fn run<T, C>(inner: Arc<ClientInner<T, C>>)
434 where
435 T: Initiator<Connection = C>,
436 C: Connection,
437 {
438 let mut reassembler = UdpReassembler::default();
439
440 loop {
441 tokio::select! {
442 _ = inner.shutdown_token.cancelled() => {
444 break;
445 }
446 result = read_datagram(&inner, &mut reassembler) => {
448 match result {
449 Ok((session_id, address, data)) => {
450 inner.udp_dispatcher.dispatch(session_id, data, address).await;
451 }
452 Err(_e) => {
453 warn!("Error reading datagram: {}. Retrying after delay...", _e);
454 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
456 }
457 }
458 }
459 }
460 }
461 }
462
463 async fn dispatch(&self, session_id: u64, data: Bytes, address: Address) {
465 if let Some(tx) = self.dispatch_map.get(&session_id) {
466 if tx.send((data, address)).await.is_err() {
469 self.dispatch_map.remove(&session_id);
470 }
471 } else {
472 warn!(
473 "[Session][{}] Received datagram for UNKNOWN or CLOSED",
474 session_id
475 );
476 }
477 }
478
479 pub(crate) fn register_session(
481 &self,
482 session_id: u64,
483 ) -> mpsc::Receiver<(Bytes, Address)> {
484 let (tx, rx) = mpsc::channel(128); self.dispatch_map.insert(session_id, tx);
486 rx
487 }
488
489 pub(crate) fn unregister_session(&self, session_id: u64) {
491 self.dispatch_map.remove(&session_id);
492 }
493
494 pub(crate) fn fragment_id_counter(&self) -> &AtomicU32 {
496 &self.fragment_id_counter
497 }
498 }
499 }
500
501 pub mod session {
503 use super::*;
504 use crate::client::datagram::ClientInner;
505 use tokio::sync::mpsc;
506
507 pub struct UdpSession<T, C>
512 where
513 T: Initiator<Connection = C>,
514 C: Connection,
515 {
516 session_id: u64,
517 client_inner: Arc<ClientInner<T, C>>,
518 receiver: mpsc::Receiver<(Bytes, Address)>,
519 }
520
521 impl<T, C> UdpSession<T, C>
522 where
523 T: Initiator<Connection = C>,
524 C: Connection,
525 {
526 pub(crate) fn new(
528 session_id: u64,
529 client_inner: Arc<ClientInner<T, C>>,
530 receiver: mpsc::Receiver<(Bytes, Address)>,
531 ) -> Self {
532 Self {
533 session_id,
534 client_inner,
535 receiver,
536 }
537 }
538
539 pub async fn send_to(&self, data: Bytes, dest_addr: Address) -> io::Result<()> {
541 send_datagram(
542 &self.client_inner,
543 self.session_id,
544 dest_addr,
545 data,
546 self.client_inner.udp_dispatcher.fragment_id_counter(),
547 )
548 .await
549 }
550
551 pub async fn recv_from(&mut self) -> Option<(Bytes, Address)> {
555 self.receiver.recv().await
556 }
557 }
558
559 impl<T, C> Drop for UdpSession<T, C>
560 where
561 T: Initiator<Connection = C>,
562 C: Connection,
563 {
564 fn drop(&mut self) {
565 self.client_inner
568 .udp_dispatcher
569 .unregister_session(self.session_id);
570 }
571 }
572 }
573}