1use crate::protocol::Kcp;
2use crate::transport::*;
3use crate::{conv::ConvCache, stream::*};
4
5use ::bytes::{Bytes, BytesMut};
6use ::futures::{
7 future::{poll_immediate, ready},
8 Sink, SinkExt, Stream, StreamExt,
9};
10use ::hashlink::LinkedHashMap;
11use ::std::{
12 io,
13 net::{Ipv4Addr, Ipv6Addr, SocketAddr},
14 pin::Pin,
15 sync::Arc,
16 task::{Context, Poll},
17};
18use ::tokio::{
19 net::{lookup_host, ToSocketAddrs, UdpSocket},
20 select,
21 sync::mpsc::{
22 channel, unbounded_channel, OwnedPermit, Receiver, Sender, UnboundedReceiver,
23 UnboundedSender,
24 },
25 task::JoinHandle,
26};
27use ::tokio_util::{codec::BytesCodec, sync::CancellationToken, udp::UdpFramed};
28
29pub struct KcpUdpStream {
30 config: Arc<KcpConfig>,
31 stream_rx: Receiver<(KcpStream, SocketAddr)>,
32 token: CancellationToken,
33 task: Option<JoinHandle<()>>,
34}
35
36impl KcpUdpStream {
37 pub async fn listen<A: ToSocketAddrs>(
38 config: Arc<KcpConfig>,
39 addr: A,
40 backlog: usize,
41 conv_cache: Option<ConvCache>,
42 ) -> io::Result<Self> {
43 let udp = UdpSocket::bind(addr).await?;
44 Self::socket_listen(config, udp, backlog, conv_cache)
45 }
46
47 pub fn socket_listen(
48 config: Arc<KcpConfig>,
49 udp: UdpSocket,
50 backlog: usize,
51 conv_cache: Option<ConvCache>,
52 ) -> io::Result<Self> {
53 let token = CancellationToken::new();
54 let (stream_tx, stream_rx) = channel(backlog.max(8));
55 let task = Task::new(config.clone(), conv_cache, stream_tx, token.clone());
56 Ok(Self {
57 config,
58 stream_rx,
59 token,
60 task: Some(tokio::spawn(task.run(udp))),
61 })
62 }
63
64 pub async fn accept(&mut self) -> io::Result<(KcpStream, SocketAddr)> {
65 self.stream_rx
66 .recv()
67 .await
68 .ok_or_else(|| io::Error::from(io::ErrorKind::NotConnected))
69 }
70
71 pub async fn close(&mut self) -> io::Result<()> {
72 if let Some(task) = self.task.take() {
73 self.token.cancel();
74 self.stream_rx.close();
75 let _ = task.await;
76 }
77 Ok(())
78 }
79}
80
81impl KcpUdpStream {
82 pub async fn connect<A: ToSocketAddrs>(
83 config: Arc<KcpConfig>,
84 addr: A,
85 ) -> io::Result<(KcpStream, SocketAddr)> {
86 let addr = lookup_host(addr)
87 .await?
88 .next()
89 .ok_or(io::ErrorKind::AddrNotAvailable)?;
90
91 let local_addr: SocketAddr = if addr.is_ipv4() {
92 (Ipv4Addr::UNSPECIFIED, 0).into()
93 } else {
94 (Ipv6Addr::UNSPECIFIED, 0).into()
95 };
96 let udp = UdpSocket::bind(local_addr).await?;
97
98 Self::socket_connect(config, addr, udp).await
99 }
100
101 pub async fn socket_connect<A: ToSocketAddrs>(
102 config: Arc<KcpConfig>,
103 addr: A,
104 udp: UdpSocket,
105 ) -> io::Result<(KcpStream, SocketAddr)> {
106 let addr = lookup_host(addr)
107 .await?
108 .next()
109 .ok_or(io::ErrorKind::AddrNotAvailable)?;
110
111 KcpStream::connect::<_, BytesMut, _>(
112 config,
113 UdpStream::new(udp, addr),
114 futures::sink::drain(),
115 None,
116 )
117 .await
118 .map(|x| (x, addr))
119 }
120}
121
122impl Drop for KcpUdpStream {
123 fn drop(&mut self) {
124 self.token.cancel();
125 self.stream_rx.close();
126 }
127}
128
129impl Stream for KcpUdpStream {
130 type Item = (KcpStream, SocketAddr);
131
132 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
133 self.stream_rx.poll_recv(cx)
134 }
135}
136
137struct Session {
140 conv: u32,
141 session_id: Bytes,
142 peer_addr: SocketAddr,
143 sender: Sender<BytesMut>,
144 stream_permit: Option<OwnedPermit<(KcpStream, SocketAddr)>>,
145 token: CancellationToken,
146 task: Option<JoinHandle<()>>,
147}
148
149enum Message {
150 Connect(KcpStream),
151 Disconnect { conv: u32 },
152}
153
154struct Task {
155 config: Arc<KcpConfig>,
156 conv_cache: ConvCache,
157 stream_tx: Sender<(KcpStream, SocketAddr)>,
158 msg_tx: UnboundedSender<Message>,
159 msg_rx: UnboundedReceiver<Message>,
160 pkt_tx: UnboundedSender<(Bytes, SocketAddr)>,
161 pkt_rx: UnboundedReceiver<(Bytes, SocketAddr)>,
162 token: CancellationToken,
163 is_closing: bool,
164
165 conv_map: LinkedHashMap<u32, Session>,
166 sid_map: LinkedHashMap<Bytes, u32>,
167}
168
169impl Task {
170 fn new(
171 config: Arc<KcpConfig>,
172 conv_cache: Option<ConvCache>,
173 stream_tx: Sender<(KcpStream, SocketAddr)>,
174 token: CancellationToken,
175 ) -> Self {
176 let (msg_tx, msg_rx) = unbounded_channel();
177 let (pkt_tx, pkt_rx) = unbounded_channel();
178 Self {
179 config,
180 conv_cache: conv_cache.unwrap_or_else(|| ConvCache::new(0, LISTENER_CONV_TIMEOUT)),
181 stream_tx,
182 msg_tx,
183 msg_rx,
184 pkt_tx,
185 pkt_rx,
186 token,
187 is_closing: false,
188 conv_map: LinkedHashMap::new(),
189 sid_map: LinkedHashMap::new(),
190 }
191 }
192
193 async fn run(mut self, udp: UdpSocket) {
194 let mut transport = UdpFramed::new(udp, BytesCodec::new());
195
196 loop {
197 if self.is_closing {
198 match self.msg_rx.try_recv() {
200 Ok(msg) => self.process_msg(msg).await,
201 Err(_) if self.conv_map.is_empty() => break,
202 _ => (),
203 }
204 }
205
206 select! {
207 x = transport.next() => {
208 let mut recved = x;
209 for _ in 0..LISTENER_TASK_LOOP {
210 match recved {
211 Some(Ok((packet, addr))) => {
212 if let Some(session) = self.get_session(&packet, &addr) {
213 let _ = session.sender.send(packet.clone()).await;
214 }
215 }
216 Some(Err(_)) => break,
217 None => {
218 self.is_closing = true;
219 self.token.cancel();
220 break;
221 }
222 }
223
224 match poll_immediate(transport.next()).await {
226 Some(x) => recved = x,
227 _ => break,
228 }
229 }
230 }
231
232 Some(item) = self.pkt_rx.recv() => {
233 let _ = transport.feed(item).await;
234 self.try_send(&mut transport, LISTENER_TASK_LOOP).await;
236 }
237
238 Some(msg) = self.msg_rx.recv() => self.process_msg(msg).await,
239
240 _ = self.token.cancelled(), if !self.is_closing => {
241 self.is_closing = true;
242 }
243 }
244 }
245
246 self.msg_rx.close();
247 self.pkt_rx.close();
248 self.try_send(&mut transport, usize::MAX).await;
249 }
250
251 async fn process_msg(&mut self, msg: Message) {
252 match msg {
253 Message::Connect(stream) => {
254 if let Some(session) = self.conv_map.get_mut(&stream.conv()) {
255 if let Some(task) = session.task.take() {
256 let _ = task.await;
257 }
258 if let Some(permit) = session.stream_permit.take() {
259 permit.send((stream, session.peer_addr));
260 }
261 }
262 }
263 Message::Disconnect { conv } => {
264 if let Some(session) = self.conv_map.remove(&conv) {
265 self.kill_session(session).await;
266 }
267 }
268 }
269 }
270
271 async fn try_send<S: Sink<(Bytes, SocketAddr)> + Unpin>(&mut self, sink: &mut S, max: usize) {
272 for _ in 0..max {
273 match self.pkt_rx.try_recv() {
274 Ok(item) => {
275 let _ = sink.feed(item).await;
276 }
277 _ => break,
278 }
279 }
280 let _ = sink.flush().await;
281 }
282
283 fn get_session(&mut self, packet: &[u8], peer_addr: &SocketAddr) -> Option<&Session> {
285 let pkt_conv = match Kcp::read_conv(packet) {
287 Some(x) => match self.conv_map.get(&x) {
288 Some(s) if &s.peer_addr == peer_addr => return self.conv_map.get(&x),
289 Some(_) => return None,
290 _ => x,
291 },
292 _ => return None,
293 };
294
295 let session_id = match KcpStream::read_session_id(packet, &self.config.session_key) {
297 Some(x) => x,
298 _ => return None,
299 };
300
301 if let Some(&conv) = self.sid_map.get(session_id) {
302 if conv == pkt_conv || pkt_conv == Kcp::SYN_CONV {
303 match self.conv_map.get(&conv) {
304 x @ Some(s) if &s.peer_addr == peer_addr => return x,
305 _ => (),
306 }
307 }
308 None
309 } else if self.is_closing
310 || pkt_conv != Kcp::SYN_CONV
311 || session_id.len() != self.config.session_id_len
312 {
313 None
315 } else {
316 let stream_permit = match self.stream_tx.clone().try_reserve_owned() {
318 Ok(x) => x,
319 _ => return None,
320 };
321
322 let conv = self.conv_cache.allocate(|x| self.conv_map.contains_key(x));
324
325 let (sender, receiver) = channel(self.config.snd_wnd as usize);
326 let token = self.token.child_token();
327
328 let session_id = Bytes::copy_from_slice(session_id);
329 self.sid_map.insert(session_id.clone(), conv);
330 self.conv_map.insert(
331 conv,
332 Session {
333 conv,
334 session_id,
335 peer_addr: *peer_addr,
336 sender,
337 token: token.clone(),
338 stream_permit: Some(stream_permit),
339 task: Some(tokio::spawn(Self::accept_stream(
340 self.config.clone(),
341 conv,
342 *peer_addr,
343 receiver,
344 self.pkt_tx.clone(),
345 self.msg_tx.clone(),
346 token,
347 ))),
348 },
349 );
350 self.conv_map.get(&conv)
351 }
352 }
353
354 async fn accept_stream(
355 config: Arc<KcpConfig>,
356 conv: u32,
357 peer_addr: SocketAddr,
358 receiver: Receiver<BytesMut>,
359 pkt_tx: UnboundedSender<(Bytes, SocketAddr)>,
360 msg_tx: UnboundedSender<Message>,
361 token: CancellationToken,
362 ) {
363 let disconnect = UnboundedSink::new(msg_tx.clone())
364 .with(move |conv: u32| ready(Ok::<_, io::Error>(Message::Disconnect { conv })));
365 if let Ok(stream) = KcpStream::accept(
366 config,
367 conv,
368 UdpMpscStream::new(Some(pkt_tx), receiver, peer_addr),
369 disconnect,
370 Some(token),
371 )
372 .await
373 {
374 let _ = msg_tx.send(Message::Connect(stream));
375 }
376 }
377
378 async fn kill_session(&mut self, mut session: Session) {
379 self.conv_cache.add(session.conv);
381 self.sid_map.remove(&session.session_id);
382 if let Some(task) = session.task.take() {
383 session.token.cancel();
384 let _ = task.await;
385 }
386 }
387}