// kcp/udp.rs

1use crate::protocol::Kcp;
2use crate::transport::*;
3use crate::{conv::ConvCache, stream::*};
4
5use ::bytes::{Bytes, BytesMut};
6use ::futures::{
7    future::{poll_immediate, ready},
8    Sink, SinkExt, Stream, StreamExt,
9};
10use ::hashlink::LinkedHashMap;
11use ::std::{
12    io,
13    net::{Ipv4Addr, Ipv6Addr, SocketAddr},
14    pin::Pin,
15    sync::Arc,
16    task::{Context, Poll},
17};
18use ::tokio::{
19    net::{lookup_host, ToSocketAddrs, UdpSocket},
20    select,
21    sync::mpsc::{
22        channel, unbounded_channel, OwnedPermit, Receiver, Sender, UnboundedReceiver,
23        UnboundedSender,
24    },
25    task::JoinHandle,
26};
27use ::tokio_util::{codec::BytesCodec, sync::CancellationToken, udp::UdpFramed};
28
29pub struct KcpUdpStream {
30    config: Arc<KcpConfig>,
31    stream_rx: Receiver<(KcpStream, SocketAddr)>,
32    token: CancellationToken,
33    task: Option<JoinHandle<()>>,
34}
35
36impl KcpUdpStream {
37    pub async fn listen<A: ToSocketAddrs>(
38        config: Arc<KcpConfig>,
39        addr: A,
40        backlog: usize,
41        conv_cache: Option<ConvCache>,
42    ) -> io::Result<Self> {
43        let udp = UdpSocket::bind(addr).await?;
44        Self::socket_listen(config, udp, backlog, conv_cache)
45    }
46
47    pub fn socket_listen(
48        config: Arc<KcpConfig>,
49        udp: UdpSocket,
50        backlog: usize,
51        conv_cache: Option<ConvCache>,
52    ) -> io::Result<Self> {
53        let token = CancellationToken::new();
54        let (stream_tx, stream_rx) = channel(backlog.max(8));
55        let task = Task::new(config.clone(), conv_cache, stream_tx, token.clone());
56        Ok(Self {
57            config,
58            stream_rx,
59            token,
60            task: Some(tokio::spawn(task.run(udp))),
61        })
62    }
63
64    pub async fn accept(&mut self) -> io::Result<(KcpStream, SocketAddr)> {
65        self.stream_rx
66            .recv()
67            .await
68            .ok_or_else(|| io::Error::from(io::ErrorKind::NotConnected))
69    }
70
71    pub async fn close(&mut self) -> io::Result<()> {
72        if let Some(task) = self.task.take() {
73            self.token.cancel();
74            self.stream_rx.close();
75            let _ = task.await;
76        }
77        Ok(())
78    }
79}
80
81impl KcpUdpStream {
82    pub async fn connect<A: ToSocketAddrs>(
83        config: Arc<KcpConfig>,
84        addr: A,
85    ) -> io::Result<(KcpStream, SocketAddr)> {
86        let addr = lookup_host(addr)
87            .await?
88            .next()
89            .ok_or(io::ErrorKind::AddrNotAvailable)?;
90
91        let local_addr: SocketAddr = if addr.is_ipv4() {
92            (Ipv4Addr::UNSPECIFIED, 0).into()
93        } else {
94            (Ipv6Addr::UNSPECIFIED, 0).into()
95        };
96        let udp = UdpSocket::bind(local_addr).await?;
97
98        Self::socket_connect(config, addr, udp).await
99    }
100
101    pub async fn socket_connect<A: ToSocketAddrs>(
102        config: Arc<KcpConfig>,
103        addr: A,
104        udp: UdpSocket,
105    ) -> io::Result<(KcpStream, SocketAddr)> {
106        let addr = lookup_host(addr)
107            .await?
108            .next()
109            .ok_or(io::ErrorKind::AddrNotAvailable)?;
110
111        KcpStream::connect::<_, BytesMut, _>(
112            config,
113            UdpStream::new(udp, addr),
114            futures::sink::drain(),
115            None,
116        )
117        .await
118        .map(|x| (x, addr))
119    }
120}
121
122impl Drop for KcpUdpStream {
123    fn drop(&mut self) {
124        self.token.cancel();
125        self.stream_rx.close();
126    }
127}
128
129impl Stream for KcpUdpStream {
130    type Item = (KcpStream, SocketAddr);
131
132    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
133        self.stream_rx.poll_recv(cx)
134    }
135}
136
137////////////////////////////////////////////////////////////////////////////////
138
139struct Session {
140    conv: u32,
141    session_id: Bytes,
142    peer_addr: SocketAddr,
143    sender: Sender<BytesMut>,
144    stream_permit: Option<OwnedPermit<(KcpStream, SocketAddr)>>,
145    token: CancellationToken,
146    task: Option<JoinHandle<()>>,
147}
148
149enum Message {
150    Connect(KcpStream),
151    Disconnect { conv: u32 },
152}
153
154struct Task {
155    config: Arc<KcpConfig>,
156    conv_cache: ConvCache,
157    stream_tx: Sender<(KcpStream, SocketAddr)>,
158    msg_tx: UnboundedSender<Message>,
159    msg_rx: UnboundedReceiver<Message>,
160    pkt_tx: UnboundedSender<(Bytes, SocketAddr)>,
161    pkt_rx: UnboundedReceiver<(Bytes, SocketAddr)>,
162    token: CancellationToken,
163    is_closing: bool,
164
165    conv_map: LinkedHashMap<u32, Session>,
166    sid_map: LinkedHashMap<Bytes, u32>,
167}
168
169impl Task {
170    fn new(
171        config: Arc<KcpConfig>,
172        conv_cache: Option<ConvCache>,
173        stream_tx: Sender<(KcpStream, SocketAddr)>,
174        token: CancellationToken,
175    ) -> Self {
176        let (msg_tx, msg_rx) = unbounded_channel();
177        let (pkt_tx, pkt_rx) = unbounded_channel();
178        Self {
179            config,
180            conv_cache: conv_cache.unwrap_or_else(|| ConvCache::new(0, LISTENER_CONV_TIMEOUT)),
181            stream_tx,
182            msg_tx,
183            msg_rx,
184            pkt_tx,
185            pkt_rx,
186            token,
187            is_closing: false,
188            conv_map: LinkedHashMap::new(),
189            sid_map: LinkedHashMap::new(),
190        }
191    }
192
193    async fn run(mut self, udp: UdpSocket) {
194        let mut transport = UdpFramed::new(udp, BytesCodec::new());
195
196        loop {
197            if self.is_closing {
198                // Try to drain all connection messages.
199                match self.msg_rx.try_recv() {
200                    Ok(msg) => self.process_msg(msg).await,
201                    Err(_) if self.conv_map.is_empty() => break,
202                    _ => (),
203                }
204            }
205
206            select! {
207                x = transport.next() => {
208                    let mut recved = x;
209                    for _ in 0..LISTENER_TASK_LOOP {
210                        match recved {
211                            Some(Ok((packet, addr))) => {
212                                if let Some(session) = self.get_session(&packet, &addr) {
213                                    let _ = session.sender.send(packet.clone()).await;
214                                }
215                            }
216                            Some(Err(_)) => break,
217                            None => {
218                                self.is_closing = true;
219                                self.token.cancel();
220                                break;
221                            }
222                        }
223
224                        // Try to receive more.
225                        match poll_immediate(transport.next()).await {
226                            Some(x) => recved = x,
227                            _ => break,
228                        }
229                    }
230                }
231
232                Some(item) = self.pkt_rx.recv() => {
233                    let _ = transport.feed(item).await;
234                    // Try send more.
235                    self.try_send(&mut transport, LISTENER_TASK_LOOP).await;
236                }
237
238                Some(msg) = self.msg_rx.recv() => self.process_msg(msg).await,
239
240                _ = self.token.cancelled(), if !self.is_closing => {
241                    self.is_closing = true;
242                }
243            }
244        }
245
246        self.msg_rx.close();
247        self.pkt_rx.close();
248        self.try_send(&mut transport, usize::MAX).await;
249    }
250
251    async fn process_msg(&mut self, msg: Message) {
252        match msg {
253            Message::Connect(stream) => {
254                if let Some(session) = self.conv_map.get_mut(&stream.conv()) {
255                    if let Some(task) = session.task.take() {
256                        let _ = task.await;
257                    }
258                    if let Some(permit) = session.stream_permit.take() {
259                        permit.send((stream, session.peer_addr));
260                    }
261                }
262            }
263            Message::Disconnect { conv } => {
264                if let Some(session) = self.conv_map.remove(&conv) {
265                    self.kill_session(session).await;
266                }
267            }
268        }
269    }
270
271    async fn try_send<S: Sink<(Bytes, SocketAddr)> + Unpin>(&mut self, sink: &mut S, max: usize) {
272        for _ in 0..max {
273            match self.pkt_rx.try_recv() {
274                Ok(item) => {
275                    let _ = sink.feed(item).await;
276                }
277                _ => break,
278            }
279        }
280        let _ = sink.flush().await;
281    }
282
283    /// Get session by a received packet.
284    fn get_session(&mut self, packet: &[u8], peer_addr: &SocketAddr) -> Option<&Session> {
285        // Find session by conv.
286        let pkt_conv = match Kcp::read_conv(packet) {
287            Some(x) => match self.conv_map.get(&x) {
288                Some(s) if &s.peer_addr == peer_addr => return self.conv_map.get(&x),
289                Some(_) => return None,
290                _ => x,
291            },
292            _ => return None,
293        };
294
295        // Try to accept a new connection.
296        let session_id = match KcpStream::read_session_id(packet, &self.config.session_key) {
297            Some(x) => x,
298            _ => return None,
299        };
300
301        if let Some(&conv) = self.sid_map.get(session_id) {
302            if conv == pkt_conv || pkt_conv == Kcp::SYN_CONV {
303                match self.conv_map.get(&conv) {
304                    x @ Some(s) if &s.peer_addr == peer_addr => return x,
305                    _ => (),
306                }
307            }
308            None
309        } else if self.is_closing
310            || pkt_conv != Kcp::SYN_CONV
311            || session_id.len() != self.config.session_id_len
312        {
313            // It's not a SYN handshake packet.
314            None
315        } else {
316            // Get permit in the backlog limitation.
317            let stream_permit = match self.stream_tx.clone().try_reserve_owned() {
318                Ok(x) => x,
319                _ => return None,
320            };
321
322            // New KCP conv.
323            let conv = self.conv_cache.allocate(|x| self.conv_map.contains_key(x));
324
325            let (sender, receiver) = channel(self.config.snd_wnd as usize);
326            let token = self.token.child_token();
327
328            let session_id = Bytes::copy_from_slice(session_id);
329            self.sid_map.insert(session_id.clone(), conv);
330            self.conv_map.insert(
331                conv,
332                Session {
333                    conv,
334                    session_id,
335                    peer_addr: *peer_addr,
336                    sender,
337                    token: token.clone(),
338                    stream_permit: Some(stream_permit),
339                    task: Some(tokio::spawn(Self::accept_stream(
340                        self.config.clone(),
341                        conv,
342                        *peer_addr,
343                        receiver,
344                        self.pkt_tx.clone(),
345                        self.msg_tx.clone(),
346                        token,
347                    ))),
348                },
349            );
350            self.conv_map.get(&conv)
351        }
352    }
353
354    async fn accept_stream(
355        config: Arc<KcpConfig>,
356        conv: u32,
357        peer_addr: SocketAddr,
358        receiver: Receiver<BytesMut>,
359        pkt_tx: UnboundedSender<(Bytes, SocketAddr)>,
360        msg_tx: UnboundedSender<Message>,
361        token: CancellationToken,
362    ) {
363        let disconnect = UnboundedSink::new(msg_tx.clone())
364            .with(move |conv: u32| ready(Ok::<_, io::Error>(Message::Disconnect { conv })));
365        if let Ok(stream) = KcpStream::accept(
366            config,
367            conv,
368            UdpMpscStream::new(Some(pkt_tx), receiver, peer_addr),
369            disconnect,
370            Some(token),
371        )
372        .await
373        {
374            let _ = msg_tx.send(Message::Connect(stream));
375        }
376    }
377
378    async fn kill_session(&mut self, mut session: Session) {
379        // Add the killed conv to cache to avoid conv confliction.
380        self.conv_cache.add(session.conv);
381        self.sid_map.remove(&session.session_id);
382        if let Some(task) = session.task.take() {
383            session.token.cancel();
384            let _ = task.await;
385        }
386    }
387}