// kcp_sys/endpoint.rs

1use std::sync::{
2    atomic::{AtomicBool, AtomicU32},
3    Arc,
4};
5
6use anyhow::Context;
7use bytes::{Bytes, BytesMut};
8use dashmap::DashMap;
9use parking_lot::Mutex;
10use tokio::{select, sync::Notify, task::JoinSet, time::timeout};
11use tracing::Instrument;
12
13use crate::{
14    error::Error,
15    ffi_safe::{Kcp, KcpConfig},
16    packet_def::KcpPacket,
17    state::{KcpConnectionFSM, PacketHeaderFlagManipulator},
18};
19
20pub type Sender<T> = tokio::sync::mpsc::Sender<T>;
21pub type Receiver<T> = tokio::sync::mpsc::Receiver<T>;
22
23pub type KcpPakcetSender = Sender<KcpPacket>;
24pub type KcpPacketReceiver = Receiver<KcpPacket>;
25
26pub type KcpStreamSender = Sender<BytesMut>;
27pub type KcpStreamReceiver = Receiver<BytesMut>;
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
30pub struct ConnId {
31    conv: u32,
32    src_session_id: u32,
33    dst_session_id: u32,
34}
35
36impl From<&KcpPacket> for ConnId {
37    fn from(packet: &KcpPacket) -> Self {
38        Self {
39            conv: packet.header().conv(),
40            src_session_id: packet.header().src_session_id(),
41            dst_session_id: packet.header().dst_session_id(),
42        }
43    }
44}
45
46impl ConnId {
47    fn fill_packet_header(&self, packet: &mut KcpPacket) {
48        packet
49            .mut_header()
50            .set_conv(self.conv)
51            .set_src_session_id(self.src_session_id)
52            .set_dst_session_id(self.dst_session_id);
53    }
54}
55
/// Wake-up primitives shared between a connection's update, send and recv tasks.
struct KcpConnectionInner {
    /// Wakes the update task early (after new input or after data was queued for sending).
    update_notifier: Notify,
    /// Wakes the recv task; fired by the update task and by `close_recv`.
    recv_notifier: Notify,
    /// Wakes the send task while it waits for send-window space; fired by the update task.
    send_notifier: Notify,

    /// Set when a packet was fed into KCP; the update task clears it and wakes the recv task.
    has_new_input: AtomicBool,
    /// Set by the send task before it waits; the update task clears it and wakes the send task.
    waiting_new_send_window: AtomicBool,
}
64
/// One KCP connection: the KCP instance plus the background tasks and channels that
/// bridge it to the user-facing byte streams.
struct KcpConnection {
    conn_id: ConnId,
    /// KCP state, shared with the background tasks spawned by `run`.
    kcp: Arc<Mutex<Box<Kcp>>>,

    /// Notifiers and flags shared with the background tasks.
    inner: Arc<KcpConnectionInner>,

    /// User side of the outgoing stream; handed out once by `send_sender()`.
    send_sender: Option<Sender<BytesMut>>,
    /// Consumed by the send task spawned in `run`.
    send_receiver: Option<Receiver<BytesMut>>,

    /// Moved into the recv task spawned in `run`, which delivers received data through it.
    recv_sender: Option<Sender<BytesMut>>,
    /// User side of the incoming stream; handed out once by `recv_receiver()`.
    recv_receiver: Option<Receiver<BytesMut>>,

    /// Notified when the send task has drained everything, and again when this struct is dropped.
    send_close_notifier: Arc<Notify>,
    /// Set once the peer has closed its sending side.
    recv_closed: Arc<AtomicBool>,

    /// Background tasks; a tokio `JoinSet` aborts its tasks when it is dropped.
    tasks: JoinSet<()>,
}
82
83impl KcpConnection {
84    pub fn new(conn_id: ConnId) -> Result<Self, Error> {
85        let kcp = Kcp::new(KcpConfig::new_turbo(conn_id.conv))?;
86
87        let (send_sender, send_receiver) = tokio::sync::mpsc::channel(128);
88        let (recv_sender, recv_receiver) = tokio::sync::mpsc::channel(128);
89
90        Ok(Self {
91            conn_id,
92            kcp: Arc::new(Mutex::new(kcp)),
93
94            inner: Arc::new(KcpConnectionInner {
95                update_notifier: Notify::new(),
96                recv_notifier: Notify::new(),
97                send_notifier: Notify::new(),
98
99                has_new_input: AtomicBool::new(false),
100                waiting_new_send_window: AtomicBool::new(false),
101            }),
102
103            send_sender: Some(send_sender),
104            send_receiver: Some(send_receiver),
105
106            recv_sender: Some(recv_sender),
107            recv_receiver: Some(recv_receiver),
108
109            send_close_notifier: Arc::new(Notify::new()),
110            recv_closed: Arc::new(AtomicBool::new(false)),
111
112            tasks: JoinSet::new(),
113        })
114    }
115
116    pub fn run(&mut self, output_sender: KcpPakcetSender) {
117        let conn_id = self.conn_id;
118        self.kcp
119            .lock()
120            .set_output_cb(Box::new(move |conv, data: BytesMut| {
121                let mut kcp_packet = KcpPacket::new_with_payload(&data);
122                conn_id.fill_packet_header(&mut kcp_packet);
123                kcp_packet.mut_header().set_data(true).set_ack(true);
124                tracing::trace!(?conv, "sending output data: {:?}", kcp_packet);
125                if let Err(e) = output_sender.try_send(kcp_packet) {
126                    tracing::debug!(?e, ?conn_id, "send output data failed");
127                }
128                Ok(())
129            }));
130
131        // kcp updater
132        let inner = self.inner.clone();
133        let kcp = self.kcp.clone();
134        let recv_closed = self.recv_closed.clone();
135        self.tasks.spawn(async move {
136            loop {
137                let next_update_ms = kcp.lock().next_update_delay_ms();
138                select! {
139                    _ = tokio::time::sleep(tokio::time::Duration::from_millis(next_update_ms as u64)) => {}
140                    _ = inner.update_notifier.notified() => {}
141                }
142
143                kcp.lock().update();
144
145                if inner.has_new_input.swap(false, std::sync::atomic::Ordering::SeqCst) {
146                    inner.recv_notifier.notify_one();
147                }
148
149                if inner.waiting_new_send_window.swap(false, std::sync::atomic::Ordering::SeqCst) {
150                    inner.send_notifier.notify_one();
151                }
152
153                if recv_closed.load(std::sync::atomic::Ordering::Relaxed) {
154                    inner.recv_notifier.notify_one();
155                }
156            }
157        });
158
159        // handle packet send
160        let kcp = self.kcp.clone();
161        let inner = self.inner.clone();
162        let mut send_receiver = self.send_receiver.take().unwrap();
163        let send_close_notifier = self.send_close_notifier.clone();
164        self.tasks.spawn(
165            async move {
166                while let Some(data) = send_receiver.recv().await {
167                    loop {
168                        let (waitsnd, sndwnd) = {
169                            let kcp = kcp.lock();
170                            (kcp.waitsnd(), kcp.sendwnd())
171                        };
172                        if waitsnd > 2 * sndwnd {
173                            inner
174                                .waiting_new_send_window
175                                .store(true, std::sync::atomic::Ordering::SeqCst);
176                            inner.send_notifier.notified().await;
177                        } else {
178                            break;
179                        }
180                    }
181                    kcp.lock().send(data.freeze()).unwrap();
182                    kcp.lock().flush();
183                    inner.update_notifier.notify_one();
184                }
185
186                tracing::debug!(
187                    ?conn_id,
188                    "connection packet sender close, waiting for waitsnd to be 0"
189                );
190
191                // waiting for waitsnd to be 0
192                while kcp.lock().waitsnd() > 0 {
193                    inner
194                        .waiting_new_send_window
195                        .store(true, std::sync::atomic::Ordering::SeqCst);
196                    inner.send_notifier.notified().await;
197                }
198
199                send_close_notifier.notify_one();
200                tracing::debug!(?conn_id, "connection packet send task done");
201            }
202            .instrument(tracing::trace_span!("send_task", conn = ?conn_id)),
203        );
204
205        // handle packet recv
206        let kcp = self.kcp.clone();
207        let inner = self.inner.clone();
208        let conn_id = self.conn_id;
209        let recv_sender = self.recv_sender.take().unwrap();
210        let recv_closed = self.recv_closed.clone();
211        self.tasks.spawn(
212            async move {
213                let mut buf = BytesMut::new();
214                while !recv_closed.load(std::sync::atomic::Ordering::Relaxed) {
215                    let peeksize = kcp.lock().peeksize();
216                    if peeksize <= 0 {
217                        tracing::trace!("recv nothing, wait for next update");
218                        inner.recv_notifier.notified().await;
219                        continue;
220                    };
221
222                    if buf.capacity() < peeksize as usize {
223                        buf.reserve(std::cmp::max(peeksize as usize, 4096));
224                    }
225                    kcp.lock().recv(&mut buf).unwrap();
226                    tracing::trace!("recv data ({}): {:?}", buf.len(), buf);
227                    assert_ne!(0, buf.len());
228                    let send_ret = recv_sender.send(buf.split()).await;
229                    if let Err(_) = send_ret {
230                        break;
231                    }
232                }
233
234                tracing::debug!(?conn_id, "connection packet recv task done");
235            }
236            .instrument(tracing::trace_span!("recv_task", conn = ?conn_id)),
237        );
238    }
239
240    fn handle_input(&mut self, packet: &KcpPacket) -> Result<(), Error> {
241        self.kcp.lock().handle_input(packet.payload())?;
242        self.inner
243            .has_new_input
244            .store(true, std::sync::atomic::Ordering::SeqCst);
245        self.inner.update_notifier.notify_one();
246        Ok(())
247    }
248
249    fn send_sender(&mut self) -> KcpStreamSender {
250        self.send_sender.take().unwrap()
251    }
252
253    fn recv_receiver(&mut self) -> KcpStreamReceiver {
254        self.recv_receiver.take().unwrap()
255    }
256
257    fn send_close_notifier(&self) -> Arc<Notify> {
258        self.send_close_notifier.clone()
259    }
260
261    fn close_recv(&self) {
262        self.recv_closed
263            .store(true, std::sync::atomic::Ordering::SeqCst);
264        self.inner.recv_notifier.notify_one();
265    }
266}
267
268impl Drop for KcpConnection {
269    fn drop(&mut self) {
270        self.send_close_notifier.notify_one();
271    }
272}
273
274impl PacketHeaderFlagManipulator for KcpPacket {
275    fn has_syn(&self) -> bool {
276        self.header().is_syn()
277    }
278
279    fn has_ack(&self) -> bool {
280        self.header().is_ack()
281    }
282
283    fn has_fin(&self) -> bool {
284        self.header().is_fin()
285    }
286
287    fn has_rst(&self) -> bool {
288        self.header().is_rst()
289    }
290
291    fn has_data(&self) -> bool {
292        self.header().is_data()
293    }
294
295    fn set_syn(&mut self, value: bool) {
296        self.mut_header().set_syn(value);
297    }
298
299    fn set_ack(&mut self, value: bool) {
300        self.mut_header().set_ack(value);
301    }
302
303    fn set_fin(&mut self, value: bool) {
304        self.mut_header().set_fin(value);
305    }
306
307    fn set_rst(&mut self, value: bool) {
308        self.mut_header().set_rst(value);
309    }
310
311    fn set_data(&mut self, value: bool) {
312        self.mut_header().set_data(value);
313    }
314}
315
/// Per-connection handshake/close bookkeeping kept by the endpoint.
struct KcpConnectionState {
    /// The connection state machine.
    fsm: KcpConnectionFSM,
    /// Fired with `notify_one` whenever `handle_packet` changes `fsm`; `connect` waits on it.
    notify: Arc<Notify>,
    /// Payload of the opening packet (set by `connect`, or taken from the first received packet).
    conn_data: Bytes,
    /// Last time a pong or any state-machine packet was seen; checked against a 60 s timeout.
    last_pong: std::time::Instant,
}
322
323impl std::fmt::Debug for KcpConnectionState {
324    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325        f.debug_struct("KcpConnectionState")
326            .field("fsm", &self.fsm)
327            .finish()
328    }
329}
330
331impl KcpConnectionState {
332    fn new(fsm: KcpConnectionFSM) -> Self {
333        Self {
334            fsm,
335            notify: Arc::new(Notify::new()),
336            conn_data: Bytes::new(),
337            last_pong: std::time::Instant::now(),
338        }
339    }
340
341    fn handle_packet(&mut self, packet: &KcpPacket) -> Result<Option<KcpPacket>, Error> {
342        self.notify_pong();
343        let mut out_packet = None;
344        let old_state = self.fsm.clone();
345        let _ = self.fsm.handle_packet(packet, &mut out_packet);
346        if old_state != self.fsm {
347            self.notify.notify_one();
348            return Ok(out_packet);
349        }
350        Ok(None)
351    }
352
353    fn notify(&self) -> Arc<Notify> {
354        self.notify.clone()
355    }
356
357    fn is_established(&self) -> bool {
358        matches!(self.fsm, KcpConnectionFSM::Established)
359    }
360
361    fn is_peer_closed(&self) -> bool {
362        matches!(
363            self.fsm,
364            KcpConnectionFSM::PeerClosed | KcpConnectionFSM::Closed
365        )
366    }
367
368    fn is_local_closed(&self) -> bool {
369        matches!(
370            self.fsm,
371            KcpConnectionFSM::LocalClosed | KcpConnectionFSM::Closed
372        )
373    }
374
375    fn is_closed(&self) -> bool {
376        matches!(self.fsm, KcpConnectionFSM::Closed)
377    }
378
379    fn set_data(&mut self, data: Bytes) {
380        self.conn_data = data;
381    }
382
383    fn notify_pong(&mut self) {
384        self.last_pong = std::time::Instant::now();
385    }
386
387    fn is_pong_timeout(&self) -> bool {
388        self.last_pong.elapsed() > std::time::Duration::from_secs(60)
389    }
390}
391
392struct KcpEndpointData {
393    cur_conv: AtomicU32,
394    conn_map: DashMap<ConnId, KcpConnection>,
395    state_map: DashMap<ConnId, KcpConnectionState>,
396}
397
398impl KcpEndpointData {
399    fn new() -> Self {
400        Self {
401            cur_conv: AtomicU32::new(rand::random()),
402            conn_map: DashMap::new(),
403            state_map: DashMap::new(),
404        }
405    }
406}
407
/// Factory producing the `KcpConfig` for a given `conv`; installed with
/// `KcpEndpoint::set_kcp_config_factory`.
pub type KcpConfigFactory = Box<dyn Fn(u32) -> KcpConfig + Send + Sync>;

/// Multiplexes KCP connections over a single packet transport.
///
/// The caller feeds received packets in through `input_sender()` and transmits whatever
/// comes out of `output_receiver()`; moving packets over the real network is up to the caller.
pub struct KcpEndpoint {
    /// Random id, used in the `Debug` output and tracing spans.
    id: u64,
    /// Connection and state maps shared with the background tasks.
    data: Arc<KcpEndpointData>,

    /// Packets arriving from the transport.
    input_sender: KcpPakcetSender,
    /// Taken by `run` for the input-processing task.
    input_receiver: Option<KcpPacketReceiver>,

    /// Packets to be sent over the transport.
    output_sender: KcpPakcetSender,
    /// Handed out once through `output_receiver()`.
    output_receiver: Option<KcpPacketReceiver>,

    /// Connections that became established on the input path, waiting for `accept`.
    new_conn_sender: tokio::sync::mpsc::Sender<ConnId>,
    /// Shared so `accept` can be called through `&self`.
    new_conn_receiver: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<ConnId>>>,

    kcp_config_factory: KcpConfigFactory,

    /// Endpoint-level background tasks started by `run`.
    tasks: JoinSet<()>,
}
427
428impl std::fmt::Debug for KcpEndpoint {
429    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
430        f.debug_struct("KcpEndpoint").field("id", &self.id).finish()
431    }
432}
433
434impl KcpEndpoint {
435    pub fn new() -> Self {
436        let (input_sender, input_receiver) = tokio::sync::mpsc::channel(1024);
437        let (output_sender, output_receiver) = tokio::sync::mpsc::channel(1024);
438        let (new_conn_sender, new_conn_receiver) = tokio::sync::mpsc::channel(4);
439
440        Self {
441            id: rand::random(),
442            data: Arc::new(KcpEndpointData::new()),
443
444            input_sender,
445            input_receiver: Some(input_receiver),
446
447            output_sender,
448            output_receiver: Some(output_receiver),
449
450            new_conn_sender,
451            new_conn_receiver: Arc::new(tokio::sync::Mutex::new(new_conn_receiver)),
452
453            kcp_config_factory: Box::new(|conv| KcpConfig::new_turbo(conv)),
454
455            tasks: JoinSet::new(),
456        }
457    }
458
459    pub fn set_kcp_config_factory(&mut self, factory: KcpConfigFactory) {
460        self.kcp_config_factory = factory;
461    }
462
463    async fn try_handle_pingpong(
464        data: &KcpEndpointData,
465        packet: &KcpPacket,
466        output_sender: &KcpPakcetSender,
467    ) -> bool {
468        if !packet.header().is_ping() {
469            return false;
470        }
471
472        if !packet.header().is_pong() {
473            let conn_id = ConnId::from(packet);
474            let need_send_pong = data
475                .state_map
476                .get_mut(&conn_id)
477                .map(|x| !x.is_local_closed())
478                .unwrap_or(false);
479
480            let mut out_packet = packet.clone();
481            if need_send_pong {
482                out_packet.mut_header().set_pong(true);
483            } else {
484                out_packet.mut_header().set_ping(false);
485                out_packet.mut_header().set_rst(true);
486            };
487
488            tracing::trace!("sending pong packet: {:?}", out_packet);
489            let ret = output_sender.send(out_packet).await;
490            if let Err(e) = ret {
491                tracing::error!(?e, "send pong packet failed");
492            }
493        } else {
494            let conv = ConnId::from(packet);
495            if let Some(mut state) = data.state_map.get_mut(&conv) {
496                state.notify_pong();
497            }
498        }
499
500        true
501    }
502
503    pub async fn run(&mut self) {
504        let mut input_receiver = self.input_receiver.take().unwrap();
505        let data = self.data.clone();
506        let output_sender = self.output_sender.clone();
507        let new_conn_sender = self.new_conn_sender.clone();
508
509        self.tasks.spawn(
510            async move {
511                while let Some(packet) = input_receiver.recv().await {
512                    tracing::trace!("recv packet: {:?}", packet);
513                    if Self::try_handle_pingpong(&data, &packet, &output_sender).await {
514                        continue;
515                    }
516
517                    let conv = ConnId::from(&packet);
518                    if packet.header().is_data() && packet.payload().len() > 0 {
519                        if let Some(mut conn) = data.conn_map.get_mut(&conv) {
520                            if let Err(e) = conn.handle_input(&packet) {
521                                tracing::error!(?e, ?conv, "handle input on connection failed");
522                            } else {
523                                tracing::trace!(?conv, "handle input on connection done");
524                            }
525                        } else {
526                            tracing::debug!(
527                                ?conv,
528                                ?packet,
529                                "no conn for conv when handling data packet"
530                            );
531                        }
532                    }
533
534                    let mut state_ref = data.state_map.get_mut(&conv);
535                    let state = state_ref.as_deref_mut();
536                    let mut out_packet: Option<KcpPacket> = None;
537                    if state.is_none() {
538                        if packet.header().is_rst() {
539                            tracing::debug!(?conv, "reset packet for conn, but no state");
540                            continue;
541                        }
542                        let mut tmp_fsm = KcpConnectionFSM::listen();
543                        let res = tmp_fsm.handle_packet(&packet, &mut out_packet);
544                        tracing::trace!(
545                            ?conv,
546                            ?state,
547                            ?out_packet,
548                            "handle first packet for conn, ret: {:?}",
549                            res
550                        );
551                        if res.is_ok() {
552                            let mut conn_state = KcpConnectionState::new(tmp_fsm);
553                            conn_state.set_data(packet.payload().to_vec().into());
554                            data.state_map.insert(conv, conn_state);
555                        }
556                    } else {
557                        let state = state.unwrap();
558                        let prev_established = state.is_established();
559                        let ret = state.handle_packet(&packet);
560                        tracing::trace!(?conv, ?state, "handle packet for conn, ret: {:?}", ret);
561                        if ret.is_ok() {
562                            out_packet = ret.unwrap();
563                        }
564
565                        if !prev_established && state.is_established() {
566                            let _ = new_conn_sender.try_send(conv);
567                        }
568
569                        if state.is_peer_closed() {
570                            tracing::debug!(?conv, "peer half closed, close recv");
571                            data.conn_map.get_mut(&conv).map(|conn| conn.close_recv());
572                        }
573
574                        if state.is_closed() {
575                            // state map will be cleaned by periodic task
576                            tracing::debug!(?conv, "connection closed, remove state");
577                            data.conn_map.remove(&conv);
578                        }
579                    }
580
581                    drop(state_ref);
582                    if let Some(mut out_packet) = out_packet {
583                        conv.fill_packet_header(&mut out_packet);
584                        tracing::trace!(?conv, ?out_packet, "sending output packet");
585                        let ret = output_sender.send(out_packet).await;
586                        if let Err(e) = ret {
587                            tracing::error!(?e, "send output packet failed");
588                        }
589                    }
590                }
591            }
592            .instrument(tracing::trace_span!("recv_task", id = self.id)),
593        );
594
595        // conn clean task
596        let data = self.data.clone();
597        self.tasks.spawn(async move {
598            loop {
599                data.state_map.retain(|_, state| {
600                    !matches!(state.fsm, KcpConnectionFSM::Closed) && !state.is_pong_timeout()
601                });
602                data.conn_map
603                    .retain(|conn_id, _| data.state_map.contains_key(conn_id));
604                tokio::time::sleep(std::time::Duration::from_secs(10)).await;
605            }
606        });
607
608        // conn ping task
609        let data = self.data.clone();
610        let output_sender = self.output_sender.clone();
611        self.tasks.spawn(async move {
612            loop {
613                let packets = data
614                    .state_map
615                    .iter()
616                    .filter_map(|item| {
617                        let (conn_id, state) = item.pair();
618                        if state.is_closed() {
619                            return None;
620                        }
621                        let mut out_packet = KcpPacket::new(0);
622                        conn_id.fill_packet_header(&mut out_packet);
623                        out_packet.mut_header().set_ping(true);
624                        Some(out_packet)
625                    })
626                    .collect::<Vec<_>>();
627
628                for packet in packets {
629                    let ret = output_sender.send(packet).await;
630                    if let Err(e) = ret {
631                        tracing::error!(?e, "send ping packet failed");
632                    }
633                    tokio::time::sleep(std::time::Duration::from_millis(5)).await;
634                }
635
636                tokio::time::sleep(std::time::Duration::from_secs(10)).await;
637            }
638        });
639    }
640
641    fn add_conn(&self, conn_id: ConnId) -> Result<(), Error> {
642        let mut conn = KcpConnection::new(conn_id)?;
643        conn.run(self.output_sender.clone());
644
645        let data = self.data.clone();
646        let close_notifier = conn.send_close_notifier();
647
648        data.conn_map.insert(conn_id, conn);
649
650        let output_sender = self.output_sender.clone();
651        let data = Arc::downgrade(&data);
652        tokio::spawn(async move {
653            close_notifier.notified().await;
654            let Some(data) = data.upgrade() else {
655                return;
656            };
657            let mut out_packet = KcpPacket::new(0);
658            let Some(mut state) = data.state_map.get_mut(&conn_id) else {
659                return;
660            };
661
662            let close_ret = state.fsm.close(&mut out_packet);
663            let cur_state = state.fsm.clone();
664            let is_closed = state.is_closed();
665            drop(state);
666            match close_ret {
667                Ok(_) => {
668                    conn_id.fill_packet_header(&mut out_packet);
669                    output_sender.send(out_packet).await.unwrap();
670                }
671                Err(e) => {
672                    tracing::error!(?e, ?conn_id, "close connection failed");
673                }
674            }
675
676            if is_closed {
677                data.conn_map.remove(&conn_id);
678            }
679
680            tracing::debug!(?conn_id, ?cur_state, "connection close watcher done");
681        });
682
683        Ok(())
684    }
685
686    pub fn output_receiver(&mut self) -> Option<KcpPacketReceiver> {
687        self.output_receiver.take()
688    }
689
690    pub fn input_sender(&self) -> KcpPakcetSender {
691        self.input_sender.clone()
692    }
693
694    pub fn input_sender_ref(&self) -> &KcpPakcetSender {
695        &self.input_sender
696    }
697
698    pub fn conn_sender_receiver(
699        &self,
700        conn_id: ConnId,
701    ) -> Option<(KcpStreamSender, KcpStreamReceiver)> {
702        let mut conn = self.data.conn_map.get_mut(&conn_id)?;
703        Some((conn.send_sender(), conn.recv_receiver()))
704    }
705
706    pub fn conn_data(&self, conn_id: &ConnId) -> Option<Bytes> {
707        let state = self.data.state_map.get(conn_id)?;
708        Some(state.conn_data.clone())
709    }
710
711    #[tracing::instrument(ret)]
712    pub async fn connect(
713        &self,
714        timeout_dur: std::time::Duration,
715        src_session_id: u32,
716        dst_session_id: u32,
717        conn_data: Bytes,
718    ) -> Result<ConnId, Error> {
719        let mut out_packet = KcpPacket::new_with_payload(&conn_data);
720        let conn_id = loop {
721            let conv_cand = self
722                .data
723                .cur_conv
724                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
725            let conn_id = ConnId {
726                conv: conv_cand,
727                src_session_id,
728                dst_session_id,
729            };
730            if !self.data.state_map.contains_key(&conn_id) {
731                break conn_id;
732            }
733        };
734
735        let fsm = KcpConnectionFSM::connect(&mut out_packet);
736        let mut state = KcpConnectionState::new(fsm);
737        state.set_data(conn_data);
738        let notify = state.notify();
739        self.data.state_map.insert(conn_id, state);
740
741        conn_id.fill_packet_header(&mut out_packet);
742
743        tracing::trace!(?conn_id, "connect packet: {:?}", out_packet);
744        self.output_sender
745            .send(out_packet)
746            .await
747            .with_context(|| "send connect packet failed")?;
748
749        if timeout(timeout_dur, notify.notified()).await.is_err() {
750            self.data.state_map.remove(&conn_id);
751            return Err(Error::ConnectTimeout);
752        }
753
754        if let Some(state) = self.data.state_map.get(&conn_id) {
755            tracing::debug!(?conn_id, ?state, "connect done, checkin state");
756            if matches!(state.fsm, KcpConnectionFSM::Established) {
757                self.add_conn(conn_id)?;
758                return Ok(conn_id);
759            } else {
760                drop(state);
761                self.data.state_map.remove(&conn_id);
762            }
763            // if task aborted, the state map will be cleaned by periodic task
764        }
765
766        return Err(anyhow::anyhow!("connect failed").into());
767    }
768
769    pub async fn accept(&self) -> Result<ConnId, Error> {
770        let conn_receiver = self.new_conn_receiver.clone();
771
772        loop {
773            let Some(conn_id) = conn_receiver.lock().await.recv().await else {
774                return Err(Error::Shutdown);
775            };
776
777            let Some(state) = self.data.state_map.get(&conn_id) else {
778                tracing::debug!(?conn_id, "no state for conn, ignore");
779                continue;
780            };
781
782            if matches!(state.fsm, KcpConnectionFSM::Established) {
783                self.add_conn(conn_id)?;
784                return Ok(conn_id);
785            }
786        }
787    }
788}
789
#[cfg(test)]
mod tests {
    use tracing::level_filters::LevelFilter;
    use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer as _};

    use super::*;

    /// Installs a pretty, TRACE-level subscriber on stderr; call it by hand when debugging.
    fn _enable_log() {
        let console_layer = tracing_subscriber::fmt::layer()
            .pretty()
            .with_writer(std::io::stderr)
            .with_filter(LevelFilter::TRACE);

        tracing_subscriber::Registry::default()
            .with(console_layer)
            .init();
    }

    /// Spawns a task on `links` that forwards every packet from `from` into `to`.
    fn spawn_forwarder(
        links: &mut JoinSet<()>,
        mut from: KcpPacketReceiver,
        to: KcpPakcetSender,
    ) {
        links.spawn(async move {
            while let Some(packet) = from.recv().await {
                let _ = to.send(packet).await;
            }
        });
    }

    /// Builds a running client/server endpoint pair whose outputs feed each other's inputs.
    async fn prepare_test() -> (KcpEndpoint, KcpEndpoint, JoinSet<()>) {
        let mut client = KcpEndpoint::new();
        let mut server = KcpEndpoint::new();
        let mut links = JoinSet::new();

        client.run().await;
        server.run().await;

        // server -> client
        let server_out = server.output_receiver().unwrap();
        spawn_forwarder(&mut links, server_out, client.input_sender());

        // client -> server
        let client_out = client.output_receiver().unwrap();
        spawn_forwarder(&mut links, client_out, server.input_sender());

        (client, server, links)
    }

    #[tokio::test]
    async fn test_kcp_connect_and_close() {
        // Exercise the header accessor on a bare packet.
        let mut probe = KcpPacket::new(0);
        let _ = probe.mut_header().conv();

        let (client_endpoint, server_endpoint, links) = prepare_test().await;

        let (connect_ret, accept_ret) = tokio::join!(
            client_endpoint.connect(std::time::Duration::from_secs(1), 1, 3, Bytes::from("conn")),
            server_endpoint.accept()
        );

        let conv = connect_ret.unwrap();
        assert_eq!(conv, accept_ret.unwrap());

        // Both sides see the payload of the opening packet.
        for endpoint in [&client_endpoint, &server_endpoint] {
            let conn_data = endpoint.conn_data(&conv).unwrap();
            assert_eq!("conn", String::from_utf8_lossy(&conn_data));
        }

        let (client_tx, mut client_rx) = client_endpoint.conn_sender_receiver(conv).unwrap();
        let (server_tx, mut server_rx) = server_endpoint.conn_sender_receiver(conv).unwrap();

        client_tx.send(BytesMut::from("hello")).await.unwrap();
        let got = server_rx.recv().await.unwrap();
        assert_eq!("hello", String::from_utf8_lossy(&got));

        server_tx.send(BytesMut::from("world")).await.unwrap();
        let got = client_rx.recv().await.unwrap();
        assert_eq!("world", String::from_utf8_lossy(&got));

        // Half close: the client stops sending; the server can still send.
        drop(client_tx);
        assert!(server_rx.recv().await.is_none());
        server_tx.send(BytesMut::from("world")).await.unwrap();
        let got = client_rx.recv().await.unwrap();
        assert_eq!("world", String::from_utf8_lossy(&got));

        // Full close.
        drop(server_tx);
        assert!(client_rx.recv().await.is_none());

        drop(client_endpoint);
        drop(server_endpoint);

        links.join_all().await;
    }
}