1use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use bytes::Bytes;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13use tokio::sync::mpsc;
14
15use crate::session::RpcSession;
16use crate::{RpcError, TunnelChunk, parse_error_payload};
17
/// Small, copyable identifier for one tunnel channel.
///
/// Handed back by `TunnelStream::open` next to the stream itself so callers
/// can refer to the channel without holding the stream.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TunnelHandle {
    /// Id of the channel the paired stream was opened on.
    pub channel_id: u32,
}
23
/// Async byte stream carried over a single channel of an [`RpcSession`].
///
/// Reading drains tunnel chunks from the receiver obtained from
/// `RpcSession::register_tunnel`; writing sends data as chunks through
/// `RpcSession::send_chunk`. Dropping the stream unregisters the channel
/// from the session.
pub struct TunnelStream {
    /// Channel this stream reads from and writes to.
    channel_id: u32,
    /// Session used to register, send on, close and unregister the channel.
    session: Arc<RpcSession>,
    /// Incoming chunks for `channel_id`, as returned by `register_tunnel`.
    rx: mpsc::Receiver<TunnelChunk>,

    /// Payload of the current chunk that has not yet been copied to a reader.
    read_buf: Bytes,
    /// Once set, reads complete with `Ok(())` without filling the buffer (EOF).
    read_eof: bool,
    /// The chunk held in `read_buf` carried the end-of-stream flag, so EOF is
    /// reported as soon as that buffer has been fully read.
    read_eos_after_buf: bool,
    // One-shot flags: each corresponding debug event is logged at most once.
    logged_first_read: bool,
    logged_first_write: bool,
    logged_read_eof: bool,
    logged_shutdown: bool,

    /// Outstanding `send_chunk` future, driven to completion by the write path.
    pending_send: Option<PendingSend>,
    /// Set by `poll_shutdown`; later writes fail with `BrokenPipe` and `Drop`
    /// does not send another close.
    write_closed: bool,
}
45
46type PendingSend =
47 Pin<Box<dyn std::future::Future<Output = Result<(), RpcError>> + Send + 'static>>;
48
49impl TunnelStream {
50 pub fn new(session: Arc<RpcSession>, channel_id: u32) -> Self {
54 let rx = session.register_tunnel(channel_id);
55 tracing::debug!(channel_id, "tunnel stream created");
56 Self {
57 channel_id,
58 session,
59 rx,
60 read_buf: Bytes::new(),
61 read_eof: false,
62 read_eos_after_buf: false,
63 pending_send: None,
64 write_closed: false,
65 logged_first_read: false,
66 logged_first_write: false,
67 logged_read_eof: false,
68 logged_shutdown: false,
69 }
70 }
71
72 pub fn open(session: Arc<RpcSession>) -> (TunnelHandle, Self) {
74 let channel_id = session.next_channel_id();
75 tracing::debug!(channel_id, "tunnel stream open");
76 let stream = Self::new(session, channel_id);
77 (TunnelHandle { channel_id }, stream)
78 }
79
80 pub fn channel_id(&self) -> u32 {
81 self.channel_id
82 }
83}
84
85impl Drop for TunnelStream {
86 fn drop(&mut self) {
87 tracing::debug!(
88 channel_id = self.channel_id,
89 write_closed = self.write_closed,
90 read_eof = self.read_eof,
91 "tunnel stream dropped"
92 );
93 self.session.unregister_tunnel(self.channel_id);
96
97 if !self.write_closed {
100 let session = self.session.clone();
101 let channel_id = self.channel_id;
102 tokio::spawn(async move {
103 let _ = session.close_tunnel(channel_id).await;
104 });
105 }
106 }
107}
108
109impl AsyncRead for TunnelStream {
110 fn poll_read(
111 mut self: Pin<&mut Self>,
112 cx: &mut Context<'_>,
113 buf: &mut ReadBuf<'_>,
114 ) -> Poll<std::io::Result<()>> {
115 if self.read_eof {
116 return Poll::Ready(Ok(()));
117 }
118
119 if !self.read_buf.is_empty() {
121 let to_copy = std::cmp::min(self.read_buf.len(), buf.remaining());
122 buf.put_slice(&self.read_buf.split_to(to_copy));
123
124 if self.read_buf.is_empty() && self.read_eos_after_buf {
125 self.read_eof = true;
126 }
127
128 return Poll::Ready(Ok(()));
129 }
130
131 match Pin::new(&mut self.rx).poll_recv(cx) {
133 Poll::Pending => Poll::Pending,
134 Poll::Ready(None) => {
135 self.read_eof = true;
136 if !self.logged_read_eof {
137 self.logged_read_eof = true;
138 tracing::debug!(channel_id = self.channel_id, "tunnel read EOF (rx closed)");
139 }
140 Poll::Ready(Ok(()))
141 }
142 Poll::Ready(Some(chunk)) => {
143 if !self.logged_first_read {
144 self.logged_first_read = true;
145 tracing::debug!(
146 channel_id = self.channel_id,
147 payload_len = chunk.payload_bytes().len(),
148 is_eos = chunk.is_eos(),
149 is_error = chunk.is_error(),
150 "tunnel read first chunk"
151 );
152 }
153 if chunk.is_error() {
154 let err = parse_error_payload(chunk.payload_bytes());
155 let (kind, msg) = match err {
156 RpcError::Status { code, message } => {
157 (std::io::ErrorKind::Other, format!("{code:?}: {message}"))
158 }
159 RpcError::Transport(e) => {
160 (std::io::ErrorKind::BrokenPipe, format!("{e:?}"))
161 }
162 RpcError::Cancelled => {
163 (std::io::ErrorKind::Interrupted, "cancelled".into())
164 }
165 RpcError::DeadlineExceeded => {
166 (std::io::ErrorKind::TimedOut, "deadline exceeded".into())
167 }
168 };
169 return Poll::Ready(Err(std::io::Error::new(kind, msg)));
170 }
171
172 let payload = chunk.payload_bytes();
173 if chunk.is_eos() && payload.is_empty() {
174 self.read_eof = true;
175 if !self.logged_read_eof {
176 self.logged_read_eof = true;
177 tracing::debug!(
178 channel_id = self.channel_id,
179 "tunnel read EOF (empty EOS)"
180 );
181 }
182 return Poll::Ready(Ok(()));
183 }
184
185 self.read_buf = Bytes::copy_from_slice(payload);
186 self.read_eos_after_buf = chunk.is_eos();
187
188 self.poll_read(cx, buf)
190 }
191 }
192 }
193}
194
195impl AsyncWrite for TunnelStream {
196 fn poll_write(
197 mut self: Pin<&mut Self>,
198 cx: &mut Context<'_>,
199 data: &[u8],
200 ) -> Poll<std::io::Result<usize>> {
201 if self.write_closed {
202 return Poll::Ready(Err(std::io::Error::new(
203 std::io::ErrorKind::BrokenPipe,
204 "tunnel write side closed",
205 )));
206 }
207
208 if let Some(fut) = self.pending_send.as_mut() {
210 match fut.as_mut().poll(cx) {
211 Poll::Ready(Ok(())) => self.pending_send = None,
212 Poll::Ready(Err(e)) => {
213 self.pending_send = None;
214 return Poll::Ready(Err(std::io::Error::new(
215 std::io::ErrorKind::BrokenPipe,
216 format!("send failed: {e:?}"),
217 )));
218 }
219 Poll::Pending => return Poll::Pending,
220 }
221 }
222
223 if data.is_empty() {
224 return Poll::Ready(Ok(0));
225 }
226
227 let channel_id = self.channel_id;
228 if !self.logged_first_write {
229 self.logged_first_write = true;
230 tracing::debug!(channel_id, payload_len = data.len(), "tunnel first write");
231 }
232 let session = self.session.clone();
233 let bytes = data.to_vec();
234 let len = bytes.len();
235 self.pending_send = Some(Box::pin(async move {
236 session.send_chunk(channel_id, bytes).await
237 }));
238
239 if let Some(fut) = self.pending_send.as_mut() {
241 match fut.as_mut().poll(cx) {
242 Poll::Ready(Ok(())) => {
243 self.pending_send = None;
244 Poll::Ready(Ok(len))
245 }
246 Poll::Ready(Err(e)) => {
247 self.pending_send = None;
248 Poll::Ready(Err(std::io::Error::new(
249 std::io::ErrorKind::BrokenPipe,
250 format!("send failed: {e:?}"),
251 )))
252 }
253 Poll::Pending => Poll::Pending,
254 }
255 } else {
256 Poll::Ready(Ok(len))
257 }
258 }
259
260 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
261 if let Some(fut) = self.pending_send.as_mut() {
262 match fut.as_mut().poll(cx) {
263 Poll::Ready(Ok(())) => {
264 self.pending_send = None;
265 Poll::Ready(Ok(()))
266 }
267 Poll::Ready(Err(e)) => {
268 self.pending_send = None;
269 Poll::Ready(Err(std::io::Error::new(
270 std::io::ErrorKind::BrokenPipe,
271 format!("send failed: {e:?}"),
272 )))
273 }
274 Poll::Pending => Poll::Pending,
275 }
276 } else {
277 Poll::Ready(Ok(()))
278 }
279 }
280
281 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
282 if self.write_closed {
283 return Poll::Ready(Ok(()));
284 }
285
286 match self.as_mut().poll_flush(cx) {
287 Poll::Ready(Ok(())) => {}
288 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
289 Poll::Pending => return Poll::Pending,
290 }
291
292 self.write_closed = true;
293 if !self.logged_shutdown {
294 self.logged_shutdown = true;
295 tracing::debug!(channel_id = self.channel_id, "tunnel shutdown");
296 }
297 let channel_id = self.channel_id;
298 let session = self.session.clone();
299 tokio::spawn(async move {
300 let _ = session.close_tunnel(channel_id).await;
301 });
302 Poll::Ready(Ok(()))
303 }
304}