kcp_sys/
stream.rs

1use std::{
2    pin::Pin,
3    task::ready,
4    task::{Context, Poll},
5};
6
7use bytes::{Bytes, BytesMut};
8use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9
10use crate::endpoint::{ConnId, KcpEndpoint, KcpStreamReceiver};
11
12pub struct KcpStream {
13    sender: tokio_util::sync::PollSender<BytesMut>,
14    receiver: KcpStreamReceiver,
15    conn_id: ConnId,
16    conn_data: Bytes,
17
18    partial_recv_buf: Option<BytesMut>,
19}
20
21impl std::fmt::Debug for KcpStream {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("KcpStream")
24            .field("conn_id", &self.conn_id)
25            .finish()
26    }
27}
28
29impl KcpStream {
30    pub fn new(endpoint: &KcpEndpoint, conn_id: ConnId) -> Option<Self> {
31        let (sender, receiver) = endpoint.conn_sender_receiver(conn_id)?;
32        let conn_data = endpoint.conn_data(&conn_id)?;
33        Some(Self {
34            sender: tokio_util::sync::PollSender::new(sender),
35            receiver,
36            conn_id,
37            conn_data,
38
39            partial_recv_buf: None,
40        })
41    }
42
43    pub fn conn_data(&self) -> &Bytes {
44        &self.conn_data
45    }
46
47    pub fn conn_id(&self) -> ConnId {
48        self.conn_id
49    }
50}
51
52impl AsyncRead for KcpStream {
53    fn poll_read(
54        mut self: Pin<&mut Self>,
55        cx: &mut Context,
56        buf: &mut ReadBuf,
57    ) -> Poll<std::io::Result<()>> {
58        let mut partial_recved = false;
59        if let Some(partial_recv_buf) = &mut self.partial_recv_buf {
60            assert!(partial_recv_buf.len() > 0);
61            partial_recved = true;
62
63            let len = std::cmp::min(buf.remaining(), partial_recv_buf.len());
64            buf.put_slice(&partial_recv_buf.split_to(len));
65
66            if partial_recv_buf.is_empty() {
67                self.partial_recv_buf = None;
68            }
69
70            if buf.remaining() == 0 {
71                return Poll::Ready(Ok(()));
72            }
73        }
74
75        loop {
76            let recv_ret = self.receiver.poll_recv(cx);
77            match recv_ret {
78                Poll::Ready(Some(mut read_buf)) => {
79                    partial_recved = true;
80
81                    let len = std::cmp::min(buf.remaining(), read_buf.len());
82                    buf.put_slice(&read_buf[..len]);
83
84                    if len < read_buf.len() {
85                        self.partial_recv_buf = Some(read_buf.split_off(len));
86                    }
87
88                    if buf.remaining() == 0 {
89                        return Poll::Ready(Ok(()));
90                    }
91                }
92                Poll::Ready(None) => {
93                    if partial_recved {
94                        return Poll::Ready(Ok(()));
95                    } else {
96                        return Poll::Ready(Err(std::io::Error::new(
97                            std::io::ErrorKind::BrokenPipe,
98                            "stream closed",
99                        )));
100                    }
101                }
102                Poll::Pending => {
103                    if partial_recved {
104                        return Poll::Ready(Ok(()));
105                    } else {
106                        return Poll::Pending;
107                    }
108                }
109            }
110        }
111    }
112}
113
114impl AsyncWrite for KcpStream {
115    fn poll_write(
116        mut self: Pin<&mut Self>,
117        cx: &mut Context,
118        buf: &[u8],
119    ) -> Poll<std::io::Result<usize>> {
120        let mut ret = ready!(self.sender.poll_reserve(cx));
121        if ret.is_ok() {
122            ret = self.sender.send_item(BytesMut::from(buf));
123        }
124        match ret {
125            Ok(_) => Poll::Ready(Ok(buf.len())),
126            Err(_) => Poll::Ready(Err(std::io::Error::new(
127                std::io::ErrorKind::BrokenPipe,
128                "stream closed",
129            ))),
130        }
131    }
132
133    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<std::io::Result<()>> {
134        Poll::Ready(Ok(()))
135    }
136
137    fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context) -> Poll<std::io::Result<()>> {
138        self.sender.close();
139        Poll::Ready(Ok(()))
140    }
141}