1use std::sync::{
2 atomic::{AtomicBool, AtomicU32},
3 Arc,
4};
5
6use anyhow::Context;
7use bytes::{Bytes, BytesMut};
8use dashmap::DashMap;
9use parking_lot::Mutex;
10use tokio::{select, sync::Notify, task::JoinSet, time::timeout};
11use tracing::Instrument;
12
13use crate::{
14 error::Error,
15 ffi_safe::{Kcp, KcpConfig},
16 packet_def::KcpPacket,
17 state::{KcpConnectionFSM, PacketHeaderFlagManipulator},
18};
19
20pub type Sender<T> = tokio::sync::mpsc::Sender<T>;
21pub type Receiver<T> = tokio::sync::mpsc::Receiver<T>;
22
23pub type KcpPakcetSender = Sender<KcpPacket>;
24pub type KcpPacketReceiver = Receiver<KcpPacket>;
25
26pub type KcpStreamSender = Sender<BytesMut>;
27pub type KcpStreamReceiver = Receiver<BytesMut>;
28
/// Key that identifies one KCP connection.
///
/// It is read from (`From<&KcpPacket>`) and written back to
/// (`fill_packet_header`) the packet header. Both endpoints therefore use the
/// same key for the same connection.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ConnId {
    /// KCP conversation number.
    conv: u32,
    /// Value of the header's `src_session_id` field.
    src_session_id: u32,
    /// Value of the header's `dst_session_id` field.
    dst_session_id: u32,
}
35
36impl From<&KcpPacket> for ConnId {
37 fn from(packet: &KcpPacket) -> Self {
38 Self {
39 conv: packet.header().conv(),
40 src_session_id: packet.header().src_session_id(),
41 dst_session_id: packet.header().dst_session_id(),
42 }
43 }
44}
45
46impl ConnId {
47 fn fill_packet_header(&self, packet: &mut KcpPacket) {
48 packet
49 .mut_header()
50 .set_conv(self.conv)
51 .set_src_session_id(self.src_session_id)
52 .set_dst_session_id(self.dst_session_id);
53 }
54}
55
56struct KcpConnectionInner {
57 update_notifier: Notify,
58 recv_notifier: Notify,
59 send_notifier: Notify,
60
61 has_new_input: AtomicBool,
62 waiting_new_send_window: AtomicBool,
63}
64
65struct KcpConnection {
66 conn_id: ConnId,
67 kcp: Arc<Mutex<Box<Kcp>>>,
68
69 inner: Arc<KcpConnectionInner>,
70
71 send_sender: Option<Sender<BytesMut>>,
72 send_receiver: Option<Receiver<BytesMut>>,
73
74 recv_sender: Option<Sender<BytesMut>>,
75 recv_receiver: Option<Receiver<BytesMut>>,
76
77 send_close_notifier: Arc<Notify>,
78 recv_closed: Arc<AtomicBool>,
79
80 tasks: JoinSet<()>,
81}
82
83impl KcpConnection {
84 pub fn new(conn_id: ConnId) -> Result<Self, Error> {
85 let kcp = Kcp::new(KcpConfig::new_turbo(conn_id.conv))?;
86
87 let (send_sender, send_receiver) = tokio::sync::mpsc::channel(128);
88 let (recv_sender, recv_receiver) = tokio::sync::mpsc::channel(128);
89
90 Ok(Self {
91 conn_id,
92 kcp: Arc::new(Mutex::new(kcp)),
93
94 inner: Arc::new(KcpConnectionInner {
95 update_notifier: Notify::new(),
96 recv_notifier: Notify::new(),
97 send_notifier: Notify::new(),
98
99 has_new_input: AtomicBool::new(false),
100 waiting_new_send_window: AtomicBool::new(false),
101 }),
102
103 send_sender: Some(send_sender),
104 send_receiver: Some(send_receiver),
105
106 recv_sender: Some(recv_sender),
107 recv_receiver: Some(recv_receiver),
108
109 send_close_notifier: Arc::new(Notify::new()),
110 recv_closed: Arc::new(AtomicBool::new(false)),
111
112 tasks: JoinSet::new(),
113 })
114 }
115
116 pub fn run(&mut self, output_sender: KcpPakcetSender) {
117 let conn_id = self.conn_id;
118 self.kcp
119 .lock()
120 .set_output_cb(Box::new(move |conv, data: BytesMut| {
121 let mut kcp_packet = KcpPacket::new_with_payload(&data);
122 conn_id.fill_packet_header(&mut kcp_packet);
123 kcp_packet.mut_header().set_data(true).set_ack(true);
124 tracing::trace!(?conv, "sending output data: {:?}", kcp_packet);
125 if let Err(e) = output_sender.try_send(kcp_packet) {
126 tracing::debug!(?e, ?conn_id, "send output data failed");
127 }
128 Ok(())
129 }));
130
131 let inner = self.inner.clone();
133 let kcp = self.kcp.clone();
134 let recv_closed = self.recv_closed.clone();
135 self.tasks.spawn(async move {
136 loop {
137 let next_update_ms = kcp.lock().next_update_delay_ms();
138 select! {
139 _ = tokio::time::sleep(tokio::time::Duration::from_millis(next_update_ms as u64)) => {}
140 _ = inner.update_notifier.notified() => {}
141 }
142
143 kcp.lock().update();
144
145 if inner.has_new_input.swap(false, std::sync::atomic::Ordering::SeqCst) {
146 inner.recv_notifier.notify_one();
147 }
148
149 if inner.waiting_new_send_window.swap(false, std::sync::atomic::Ordering::SeqCst) {
150 inner.send_notifier.notify_one();
151 }
152
153 if recv_closed.load(std::sync::atomic::Ordering::Relaxed) {
154 inner.recv_notifier.notify_one();
155 }
156 }
157 });
158
159 let kcp = self.kcp.clone();
161 let inner = self.inner.clone();
162 let mut send_receiver = self.send_receiver.take().unwrap();
163 let send_close_notifier = self.send_close_notifier.clone();
164 self.tasks.spawn(
165 async move {
166 while let Some(data) = send_receiver.recv().await {
167 loop {
168 let (waitsnd, sndwnd) = {
169 let kcp = kcp.lock();
170 (kcp.waitsnd(), kcp.sendwnd())
171 };
172 if waitsnd > 2 * sndwnd {
173 inner
174 .waiting_new_send_window
175 .store(true, std::sync::atomic::Ordering::SeqCst);
176 inner.send_notifier.notified().await;
177 } else {
178 break;
179 }
180 }
181 kcp.lock().send(data.freeze()).unwrap();
182 kcp.lock().flush();
183 inner.update_notifier.notify_one();
184 }
185
186 tracing::debug!(
187 ?conn_id,
188 "connection packet sender close, waiting for waitsnd to be 0"
189 );
190
191 while kcp.lock().waitsnd() > 0 {
193 inner
194 .waiting_new_send_window
195 .store(true, std::sync::atomic::Ordering::SeqCst);
196 inner.send_notifier.notified().await;
197 }
198
199 send_close_notifier.notify_one();
200 tracing::debug!(?conn_id, "connection packet send task done");
201 }
202 .instrument(tracing::trace_span!("send_task", conn = ?conn_id)),
203 );
204
205 let kcp = self.kcp.clone();
207 let inner = self.inner.clone();
208 let conn_id = self.conn_id;
209 let recv_sender = self.recv_sender.take().unwrap();
210 let recv_closed = self.recv_closed.clone();
211 self.tasks.spawn(
212 async move {
213 let mut buf = BytesMut::new();
214 while !recv_closed.load(std::sync::atomic::Ordering::Relaxed) {
215 let peeksize = kcp.lock().peeksize();
216 if peeksize <= 0 {
217 tracing::trace!("recv nothing, wait for next update");
218 inner.recv_notifier.notified().await;
219 continue;
220 };
221
222 if buf.capacity() < peeksize as usize {
223 buf.reserve(std::cmp::max(peeksize as usize, 4096));
224 }
225 kcp.lock().recv(&mut buf).unwrap();
226 tracing::trace!("recv data ({}): {:?}", buf.len(), buf);
227 assert_ne!(0, buf.len());
228 let send_ret = recv_sender.send(buf.split()).await;
229 if let Err(_) = send_ret {
230 break;
231 }
232 }
233
234 tracing::debug!(?conn_id, "connection packet recv task done");
235 }
236 .instrument(tracing::trace_span!("recv_task", conn = ?conn_id)),
237 );
238 }
239
240 fn handle_input(&mut self, packet: &KcpPacket) -> Result<(), Error> {
241 self.kcp.lock().handle_input(packet.payload())?;
242 self.inner
243 .has_new_input
244 .store(true, std::sync::atomic::Ordering::SeqCst);
245 self.inner.update_notifier.notify_one();
246 Ok(())
247 }
248
249 fn send_sender(&mut self) -> KcpStreamSender {
250 self.send_sender.take().unwrap()
251 }
252
253 fn recv_receiver(&mut self) -> KcpStreamReceiver {
254 self.recv_receiver.take().unwrap()
255 }
256
257 fn send_close_notifier(&self) -> Arc<Notify> {
258 self.send_close_notifier.clone()
259 }
260
261 fn close_recv(&self) {
262 self.recv_closed
263 .store(true, std::sync::atomic::Ordering::SeqCst);
264 self.inner.recv_notifier.notify_one();
265 }
266}
267
268impl Drop for KcpConnection {
269 fn drop(&mut self) {
270 self.send_close_notifier.notify_one();
271 }
272}
273
274impl PacketHeaderFlagManipulator for KcpPacket {
275 fn has_syn(&self) -> bool {
276 self.header().is_syn()
277 }
278
279 fn has_ack(&self) -> bool {
280 self.header().is_ack()
281 }
282
283 fn has_fin(&self) -> bool {
284 self.header().is_fin()
285 }
286
287 fn has_rst(&self) -> bool {
288 self.header().is_rst()
289 }
290
291 fn has_data(&self) -> bool {
292 self.header().is_data()
293 }
294
295 fn set_syn(&mut self, value: bool) {
296 self.mut_header().set_syn(value);
297 }
298
299 fn set_ack(&mut self, value: bool) {
300 self.mut_header().set_ack(value);
301 }
302
303 fn set_fin(&mut self, value: bool) {
304 self.mut_header().set_fin(value);
305 }
306
307 fn set_rst(&mut self, value: bool) {
308 self.mut_header().set_rst(value);
309 }
310
311 fn set_data(&mut self, value: bool) {
312 self.mut_header().set_data(value);
313 }
314}
315
316struct KcpConnectionState {
317 fsm: KcpConnectionFSM,
318 notify: Arc<Notify>,
319 conn_data: Bytes,
320 last_pong: std::time::Instant,
321}
322
323impl std::fmt::Debug for KcpConnectionState {
324 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 f.debug_struct("KcpConnectionState")
326 .field("fsm", &self.fsm)
327 .finish()
328 }
329}
330
331impl KcpConnectionState {
332 fn new(fsm: KcpConnectionFSM) -> Self {
333 Self {
334 fsm,
335 notify: Arc::new(Notify::new()),
336 conn_data: Bytes::new(),
337 last_pong: std::time::Instant::now(),
338 }
339 }
340
341 fn handle_packet(&mut self, packet: &KcpPacket) -> Result<Option<KcpPacket>, Error> {
342 self.notify_pong();
343 let mut out_packet = None;
344 let old_state = self.fsm.clone();
345 let _ = self.fsm.handle_packet(packet, &mut out_packet);
346 if old_state != self.fsm {
347 self.notify.notify_one();
348 return Ok(out_packet);
349 }
350 Ok(None)
351 }
352
353 fn notify(&self) -> Arc<Notify> {
354 self.notify.clone()
355 }
356
357 fn is_established(&self) -> bool {
358 matches!(self.fsm, KcpConnectionFSM::Established)
359 }
360
361 fn is_peer_closed(&self) -> bool {
362 matches!(
363 self.fsm,
364 KcpConnectionFSM::PeerClosed | KcpConnectionFSM::Closed
365 )
366 }
367
368 fn is_local_closed(&self) -> bool {
369 matches!(
370 self.fsm,
371 KcpConnectionFSM::LocalClosed | KcpConnectionFSM::Closed
372 )
373 }
374
375 fn is_closed(&self) -> bool {
376 matches!(self.fsm, KcpConnectionFSM::Closed)
377 }
378
379 fn set_data(&mut self, data: Bytes) {
380 self.conn_data = data;
381 }
382
383 fn notify_pong(&mut self) {
384 self.last_pong = std::time::Instant::now();
385 }
386
387 fn is_pong_timeout(&self) -> bool {
388 self.last_pong.elapsed() > std::time::Duration::from_secs(60)
389 }
390}
391
392struct KcpEndpointData {
393 cur_conv: AtomicU32,
394 conn_map: DashMap<ConnId, KcpConnection>,
395 state_map: DashMap<ConnId, KcpConnectionState>,
396}
397
398impl KcpEndpointData {
399 fn new() -> Self {
400 Self {
401 cur_conv: AtomicU32::new(rand::random()),
402 conn_map: DashMap::new(),
403 state_map: DashMap::new(),
404 }
405 }
406}
407
408pub type KcpConfigFactory = Box<dyn Fn(u32) -> KcpConfig + Send + Sync>;
409
410pub struct KcpEndpoint {
411 id: u64,
412 data: Arc<KcpEndpointData>,
413
414 input_sender: KcpPakcetSender,
415 input_receiver: Option<KcpPacketReceiver>,
416
417 output_sender: KcpPakcetSender,
418 output_receiver: Option<KcpPacketReceiver>,
419
420 new_conn_sender: tokio::sync::mpsc::Sender<ConnId>,
421 new_conn_receiver: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<ConnId>>>,
422
423 kcp_config_factory: KcpConfigFactory,
424
425 tasks: JoinSet<()>,
426}
427
428impl std::fmt::Debug for KcpEndpoint {
429 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
430 f.debug_struct("KcpEndpoint").field("id", &self.id).finish()
431 }
432}
433
434impl KcpEndpoint {
435 pub fn new() -> Self {
436 let (input_sender, input_receiver) = tokio::sync::mpsc::channel(1024);
437 let (output_sender, output_receiver) = tokio::sync::mpsc::channel(1024);
438 let (new_conn_sender, new_conn_receiver) = tokio::sync::mpsc::channel(4);
439
440 Self {
441 id: rand::random(),
442 data: Arc::new(KcpEndpointData::new()),
443
444 input_sender,
445 input_receiver: Some(input_receiver),
446
447 output_sender,
448 output_receiver: Some(output_receiver),
449
450 new_conn_sender,
451 new_conn_receiver: Arc::new(tokio::sync::Mutex::new(new_conn_receiver)),
452
453 kcp_config_factory: Box::new(|conv| KcpConfig::new_turbo(conv)),
454
455 tasks: JoinSet::new(),
456 }
457 }
458
459 pub fn set_kcp_config_factory(&mut self, factory: KcpConfigFactory) {
460 self.kcp_config_factory = factory;
461 }
462
463 async fn try_handle_pingpong(
464 data: &KcpEndpointData,
465 packet: &KcpPacket,
466 output_sender: &KcpPakcetSender,
467 ) -> bool {
468 if !packet.header().is_ping() {
469 return false;
470 }
471
472 if !packet.header().is_pong() {
473 let conn_id = ConnId::from(packet);
474 let need_send_pong = data
475 .state_map
476 .get_mut(&conn_id)
477 .map(|x| !x.is_local_closed())
478 .unwrap_or(false);
479
480 let mut out_packet = packet.clone();
481 if need_send_pong {
482 out_packet.mut_header().set_pong(true);
483 } else {
484 out_packet.mut_header().set_ping(false);
485 out_packet.mut_header().set_rst(true);
486 };
487
488 tracing::trace!("sending pong packet: {:?}", out_packet);
489 let ret = output_sender.send(out_packet).await;
490 if let Err(e) = ret {
491 tracing::error!(?e, "send pong packet failed");
492 }
493 } else {
494 let conv = ConnId::from(packet);
495 if let Some(mut state) = data.state_map.get_mut(&conv) {
496 state.notify_pong();
497 }
498 }
499
500 true
501 }
502
503 pub async fn run(&mut self) {
504 let mut input_receiver = self.input_receiver.take().unwrap();
505 let data = self.data.clone();
506 let output_sender = self.output_sender.clone();
507 let new_conn_sender = self.new_conn_sender.clone();
508
509 self.tasks.spawn(
510 async move {
511 while let Some(packet) = input_receiver.recv().await {
512 tracing::trace!("recv packet: {:?}", packet);
513 if Self::try_handle_pingpong(&data, &packet, &output_sender).await {
514 continue;
515 }
516
517 let conv = ConnId::from(&packet);
518 if packet.header().is_data() && packet.payload().len() > 0 {
519 if let Some(mut conn) = data.conn_map.get_mut(&conv) {
520 if let Err(e) = conn.handle_input(&packet) {
521 tracing::error!(?e, ?conv, "handle input on connection failed");
522 } else {
523 tracing::trace!(?conv, "handle input on connection done");
524 }
525 } else {
526 tracing::debug!(
527 ?conv,
528 ?packet,
529 "no conn for conv when handling data packet"
530 );
531 }
532 }
533
534 let mut state_ref = data.state_map.get_mut(&conv);
535 let state = state_ref.as_deref_mut();
536 let mut out_packet: Option<KcpPacket> = None;
537 if state.is_none() {
538 if packet.header().is_rst() {
539 tracing::debug!(?conv, "reset packet for conn, but no state");
540 continue;
541 }
542 let mut tmp_fsm = KcpConnectionFSM::listen();
543 let res = tmp_fsm.handle_packet(&packet, &mut out_packet);
544 tracing::trace!(
545 ?conv,
546 ?state,
547 ?out_packet,
548 "handle first packet for conn, ret: {:?}",
549 res
550 );
551 if res.is_ok() {
552 let mut conn_state = KcpConnectionState::new(tmp_fsm);
553 conn_state.set_data(packet.payload().to_vec().into());
554 data.state_map.insert(conv, conn_state);
555 }
556 } else {
557 let state = state.unwrap();
558 let prev_established = state.is_established();
559 let ret = state.handle_packet(&packet);
560 tracing::trace!(?conv, ?state, "handle packet for conn, ret: {:?}", ret);
561 if ret.is_ok() {
562 out_packet = ret.unwrap();
563 }
564
565 if !prev_established && state.is_established() {
566 let _ = new_conn_sender.try_send(conv);
567 }
568
569 if state.is_peer_closed() {
570 tracing::debug!(?conv, "peer half closed, close recv");
571 data.conn_map.get_mut(&conv).map(|conn| conn.close_recv());
572 }
573
574 if state.is_closed() {
575 tracing::debug!(?conv, "connection closed, remove state");
577 data.conn_map.remove(&conv);
578 }
579 }
580
581 drop(state_ref);
582 if let Some(mut out_packet) = out_packet {
583 conv.fill_packet_header(&mut out_packet);
584 tracing::trace!(?conv, ?out_packet, "sending output packet");
585 let ret = output_sender.send(out_packet).await;
586 if let Err(e) = ret {
587 tracing::error!(?e, "send output packet failed");
588 }
589 }
590 }
591 }
592 .instrument(tracing::trace_span!("recv_task", id = self.id)),
593 );
594
595 let data = self.data.clone();
597 self.tasks.spawn(async move {
598 loop {
599 data.state_map.retain(|_, state| {
600 !matches!(state.fsm, KcpConnectionFSM::Closed) && !state.is_pong_timeout()
601 });
602 data.conn_map
603 .retain(|conn_id, _| data.state_map.contains_key(conn_id));
604 tokio::time::sleep(std::time::Duration::from_secs(10)).await;
605 }
606 });
607
608 let data = self.data.clone();
610 let output_sender = self.output_sender.clone();
611 self.tasks.spawn(async move {
612 loop {
613 let packets = data
614 .state_map
615 .iter()
616 .filter_map(|item| {
617 let (conn_id, state) = item.pair();
618 if state.is_closed() {
619 return None;
620 }
621 let mut out_packet = KcpPacket::new(0);
622 conn_id.fill_packet_header(&mut out_packet);
623 out_packet.mut_header().set_ping(true);
624 Some(out_packet)
625 })
626 .collect::<Vec<_>>();
627
628 for packet in packets {
629 let ret = output_sender.send(packet).await;
630 if let Err(e) = ret {
631 tracing::error!(?e, "send ping packet failed");
632 }
633 tokio::time::sleep(std::time::Duration::from_millis(5)).await;
634 }
635
636 tokio::time::sleep(std::time::Duration::from_secs(10)).await;
637 }
638 });
639 }
640
641 fn add_conn(&self, conn_id: ConnId) -> Result<(), Error> {
642 let mut conn = KcpConnection::new(conn_id)?;
643 conn.run(self.output_sender.clone());
644
645 let data = self.data.clone();
646 let close_notifier = conn.send_close_notifier();
647
648 data.conn_map.insert(conn_id, conn);
649
650 let output_sender = self.output_sender.clone();
651 let data = Arc::downgrade(&data);
652 tokio::spawn(async move {
653 close_notifier.notified().await;
654 let Some(data) = data.upgrade() else {
655 return;
656 };
657 let mut out_packet = KcpPacket::new(0);
658 let Some(mut state) = data.state_map.get_mut(&conn_id) else {
659 return;
660 };
661
662 let close_ret = state.fsm.close(&mut out_packet);
663 let cur_state = state.fsm.clone();
664 let is_closed = state.is_closed();
665 drop(state);
666 match close_ret {
667 Ok(_) => {
668 conn_id.fill_packet_header(&mut out_packet);
669 output_sender.send(out_packet).await.unwrap();
670 }
671 Err(e) => {
672 tracing::error!(?e, ?conn_id, "close connection failed");
673 }
674 }
675
676 if is_closed {
677 data.conn_map.remove(&conn_id);
678 }
679
680 tracing::debug!(?conn_id, ?cur_state, "connection close watcher done");
681 });
682
683 Ok(())
684 }
685
686 pub fn output_receiver(&mut self) -> Option<KcpPacketReceiver> {
687 self.output_receiver.take()
688 }
689
690 pub fn input_sender(&self) -> KcpPakcetSender {
691 self.input_sender.clone()
692 }
693
694 pub fn input_sender_ref(&self) -> &KcpPakcetSender {
695 &self.input_sender
696 }
697
698 pub fn conn_sender_receiver(
699 &self,
700 conn_id: ConnId,
701 ) -> Option<(KcpStreamSender, KcpStreamReceiver)> {
702 let mut conn = self.data.conn_map.get_mut(&conn_id)?;
703 Some((conn.send_sender(), conn.recv_receiver()))
704 }
705
706 pub fn conn_data(&self, conn_id: &ConnId) -> Option<Bytes> {
707 let state = self.data.state_map.get(conn_id)?;
708 Some(state.conn_data.clone())
709 }
710
711 #[tracing::instrument(ret)]
712 pub async fn connect(
713 &self,
714 timeout_dur: std::time::Duration,
715 src_session_id: u32,
716 dst_session_id: u32,
717 conn_data: Bytes,
718 ) -> Result<ConnId, Error> {
719 let mut out_packet = KcpPacket::new_with_payload(&conn_data);
720 let conn_id = loop {
721 let conv_cand = self
722 .data
723 .cur_conv
724 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
725 let conn_id = ConnId {
726 conv: conv_cand,
727 src_session_id,
728 dst_session_id,
729 };
730 if !self.data.state_map.contains_key(&conn_id) {
731 break conn_id;
732 }
733 };
734
735 let fsm = KcpConnectionFSM::connect(&mut out_packet);
736 let mut state = KcpConnectionState::new(fsm);
737 state.set_data(conn_data);
738 let notify = state.notify();
739 self.data.state_map.insert(conn_id, state);
740
741 conn_id.fill_packet_header(&mut out_packet);
742
743 tracing::trace!(?conn_id, "connect packet: {:?}", out_packet);
744 self.output_sender
745 .send(out_packet)
746 .await
747 .with_context(|| "send connect packet failed")?;
748
749 if timeout(timeout_dur, notify.notified()).await.is_err() {
750 self.data.state_map.remove(&conn_id);
751 return Err(Error::ConnectTimeout);
752 }
753
754 if let Some(state) = self.data.state_map.get(&conn_id) {
755 tracing::debug!(?conn_id, ?state, "connect done, checkin state");
756 if matches!(state.fsm, KcpConnectionFSM::Established) {
757 self.add_conn(conn_id)?;
758 return Ok(conn_id);
759 } else {
760 drop(state);
761 self.data.state_map.remove(&conn_id);
762 }
763 }
765
766 return Err(anyhow::anyhow!("connect failed").into());
767 }
768
769 pub async fn accept(&self) -> Result<ConnId, Error> {
770 let conn_receiver = self.new_conn_receiver.clone();
771
772 loop {
773 let Some(conn_id) = conn_receiver.lock().await.recv().await else {
774 return Err(Error::Shutdown);
775 };
776
777 let Some(state) = self.data.state_map.get(&conn_id) else {
778 tracing::debug!(?conn_id, "no state for conn, ignore");
779 continue;
780 };
781
782 if matches!(state.fsm, KcpConnectionFSM::Established) {
783 self.add_conn(conn_id)?;
784 return Ok(conn_id);
785 }
786 }
787 }
788}
789
#[cfg(test)]
mod tests {
    use tracing::level_filters::LevelFilter;
    use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer as _};

    use super::*;

    /// Installs a pretty, TRACE-level stderr subscriber. Call manually when debugging a test.
    fn _enable_log() {
        let console_layer = tracing_subscriber::fmt::layer()
            .pretty()
            .with_writer(std::io::stderr)
            .with_filter(LevelFilter::TRACE);

        tracing_subscriber::Registry::default()
            .with(console_layer)
            .init();
    }

    /// Starts a client and a server endpoint whose outputs are cross-wired into
    /// each other's inputs. Returns both endpoints and the forwarding tasks.
    async fn prepare_test() -> (KcpEndpoint, KcpEndpoint, JoinSet<()>) {
        let mut client_endpoint = KcpEndpoint::new();
        let mut server_endpoint = KcpEndpoint::new();
        let mut t = JoinSet::new();

        client_endpoint.run().await;
        server_endpoint.run().await;

        let client_input_sender = client_endpoint.input_sender();
        let mut server_output_receiver = server_endpoint.output_receiver().unwrap();
        t.spawn(async move {
            while let Some(packet) = server_output_receiver.recv().await {
                let _ = client_input_sender.send(packet).await;
            }
        });

        let server_input_sender = server_endpoint.input_sender();
        let mut client_output_receiver = client_endpoint.output_receiver().unwrap();
        t.spawn(async move {
            while let Some(packet) = client_output_receiver.recv().await {
                let _ = server_input_sender.send(packet).await;
            }
        });

        (client_endpoint, server_endpoint, t)
    }

    #[tokio::test]
    async fn test_kcp_connect_and_close() {
        let (client_endpoint, server_endpoint, t) = prepare_test().await;

        let (connect_ret, accept_ret) = tokio::join!(
            client_endpoint.connect(std::time::Duration::from_secs(1), 1, 3, Bytes::from("conn")),
            server_endpoint.accept()
        );

        // Both sides must agree on the connection key.
        assert_eq!(*connect_ret.as_ref().unwrap(), accept_ret.unwrap());

        let conv = connect_ret.unwrap();

        let client_conn_data = client_endpoint.conn_data(&conv).unwrap();
        assert_eq!("conn", String::from_utf8_lossy(&client_conn_data));

        let server_conn_data = server_endpoint.conn_data(&conv).unwrap();
        assert_eq!("conn", String::from_utf8_lossy(&server_conn_data));

        let (client_sender, mut client_receiver) =
            client_endpoint.conn_sender_receiver(conv).unwrap();
        let (server_sender, mut server_receiver) =
            server_endpoint.conn_sender_receiver(conv).unwrap();

        client_sender.send(BytesMut::from("hello")).await.unwrap();
        let data = server_receiver.recv().await.unwrap();
        assert_eq!("hello", String::from_utf8_lossy(&data));

        server_sender.send(BytesMut::from("world")).await.unwrap();
        let data = client_receiver.recv().await.unwrap();
        assert_eq!("world", String::from_utf8_lossy(&data));

        // Half close: the server's inbound stream ends, but it can still send.
        drop(client_sender);
        assert!(server_receiver.recv().await.is_none());
        server_sender.send(BytesMut::from("world")).await.unwrap();
        let data = client_receiver.recv().await.unwrap();
        assert_eq!("world", String::from_utf8_lossy(&data));

        // Full close: the client's inbound stream ends as well.
        drop(server_sender);
        assert!(client_receiver.recv().await.is_none());

        drop(client_endpoint);
        drop(server_endpoint);

        t.join_all().await;
    }
}