1use std::{
2 pin::Pin,
3 task::ready,
4 task::{Context, Poll},
5};
6
7use bytes::{Bytes, BytesMut};
8use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9
10use crate::endpoint::{ConnId, KcpEndpoint, KcpStreamReceiver};
11
12pub struct KcpStream {
13 sender: tokio_util::sync::PollSender<BytesMut>,
14 receiver: KcpStreamReceiver,
15 conn_id: ConnId,
16 conn_data: Bytes,
17
18 partial_recv_buf: Option<BytesMut>,
19}
20
21impl std::fmt::Debug for KcpStream {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("KcpStream")
24 .field("conn_id", &self.conn_id)
25 .finish()
26 }
27}
28
29impl KcpStream {
30 pub fn new(endpoint: &KcpEndpoint, conn_id: ConnId) -> Option<Self> {
31 let (sender, receiver) = endpoint.conn_sender_receiver(conn_id)?;
32 let conn_data = endpoint.conn_data(&conn_id)?;
33 Some(Self {
34 sender: tokio_util::sync::PollSender::new(sender),
35 receiver,
36 conn_id,
37 conn_data,
38
39 partial_recv_buf: None,
40 })
41 }
42
43 pub fn conn_data(&self) -> &Bytes {
44 &self.conn_data
45 }
46
47 pub fn conn_id(&self) -> ConnId {
48 self.conn_id
49 }
50}
51
52impl AsyncRead for KcpStream {
53 fn poll_read(
54 mut self: Pin<&mut Self>,
55 cx: &mut Context,
56 buf: &mut ReadBuf,
57 ) -> Poll<std::io::Result<()>> {
58 let mut partial_recved = false;
59 if let Some(partial_recv_buf) = &mut self.partial_recv_buf {
60 assert!(partial_recv_buf.len() > 0);
61 partial_recved = true;
62
63 let len = std::cmp::min(buf.remaining(), partial_recv_buf.len());
64 buf.put_slice(&partial_recv_buf.split_to(len));
65
66 if partial_recv_buf.is_empty() {
67 self.partial_recv_buf = None;
68 }
69
70 if buf.remaining() == 0 {
71 return Poll::Ready(Ok(()));
72 }
73 }
74
75 loop {
76 let recv_ret = self.receiver.poll_recv(cx);
77 match recv_ret {
78 Poll::Ready(Some(mut read_buf)) => {
79 partial_recved = true;
80
81 let len = std::cmp::min(buf.remaining(), read_buf.len());
82 buf.put_slice(&read_buf[..len]);
83
84 if len < read_buf.len() {
85 self.partial_recv_buf = Some(read_buf.split_off(len));
86 }
87
88 if buf.remaining() == 0 {
89 return Poll::Ready(Ok(()));
90 }
91 }
92 Poll::Ready(None) => {
93 if partial_recved {
94 return Poll::Ready(Ok(()));
95 } else {
96 return Poll::Ready(Err(std::io::Error::new(
97 std::io::ErrorKind::BrokenPipe,
98 "stream closed",
99 )));
100 }
101 }
102 Poll::Pending => {
103 if partial_recved {
104 return Poll::Ready(Ok(()));
105 } else {
106 return Poll::Pending;
107 }
108 }
109 }
110 }
111 }
112}
113
114impl AsyncWrite for KcpStream {
115 fn poll_write(
116 mut self: Pin<&mut Self>,
117 cx: &mut Context,
118 buf: &[u8],
119 ) -> Poll<std::io::Result<usize>> {
120 let mut ret = ready!(self.sender.poll_reserve(cx));
121 if ret.is_ok() {
122 ret = self.sender.send_item(BytesMut::from(buf));
123 }
124 match ret {
125 Ok(_) => Poll::Ready(Ok(buf.len())),
126 Err(_) => Poll::Ready(Err(std::io::Error::new(
127 std::io::ErrorKind::BrokenPipe,
128 "stream closed",
129 ))),
130 }
131 }
132
133 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<std::io::Result<()>> {
134 Poll::Ready(Ok(()))
135 }
136
137 fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context) -> Poll<std::io::Result<()>> {
138 self.sender.close();
139 Poll::Ready(Ok(()))
140 }
141}