rapace_core/
tunnel_stream.rs

//! First-class bidirectional tunnel streams.
//!
//! A tunnel is a bidirectional byte stream multiplexed over an RPC channel ID.
//! This module wraps the low-level `register_tunnel`/`send_chunk`/`close_tunnel` APIs
//! into an ergonomic `AsyncRead + AsyncWrite` type, `TunnelStream`.

7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use bytes::Bytes;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13use tokio::sync::mpsc;
14
15use crate::session::RpcSession;
16use crate::{RpcError, TunnelChunk, parse_error_payload};
17
/// Lightweight, copyable identifier for a tunnel channel.
///
/// Returned alongside the stream by `TunnelStream::open`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TunnelHandle {
    /// Numeric ID of the RPC channel the tunnel is multiplexed over.
    pub channel_id: u32,
}
23
24/// A bidirectional byte stream over a tunnel channel.
25///
26/// - Reads consume incoming tunnel chunks from `RpcSession::register_tunnel`.
27/// - Writes send outgoing tunnel chunks via `RpcSession::send_chunk`.
28/// - `poll_shutdown` sends an EOS via `RpcSession::close_tunnel`.
29pub struct TunnelStream {
30    channel_id: u32,
31    session: Arc<RpcSession>,
32    rx: mpsc::Receiver<TunnelChunk>,
33
34    read_buf: Bytes,
35    read_eof: bool,
36    read_eos_after_buf: bool,
37    logged_first_read: bool,
38    logged_first_write: bool,
39    logged_read_eof: bool,
40    logged_shutdown: bool,
41
42    pending_send: Option<PendingSend>,
43    write_closed: bool,
44}
45
46type PendingSend =
47    Pin<Box<dyn std::future::Future<Output = Result<(), RpcError>> + Send + 'static>>;
48
49impl TunnelStream {
50    /// Create a new tunnel stream for an existing `channel_id`.
51    ///
52    /// This registers a tunnel receiver immediately, so the peer can start sending.
53    pub fn new(session: Arc<RpcSession>, channel_id: u32) -> Self {
54        let rx = session.register_tunnel(channel_id);
55        tracing::debug!(channel_id, "tunnel stream created");
56        Self {
57            channel_id,
58            session,
59            rx,
60            read_buf: Bytes::new(),
61            read_eof: false,
62            read_eos_after_buf: false,
63            pending_send: None,
64            write_closed: false,
65            logged_first_read: false,
66            logged_first_write: false,
67            logged_read_eof: false,
68            logged_shutdown: false,
69        }
70    }
71
72    /// Allocate a fresh tunnel channel ID and return a stream for it.
73    pub fn open(session: Arc<RpcSession>) -> (TunnelHandle, Self) {
74        let channel_id = session.next_channel_id();
75        tracing::debug!(channel_id, "tunnel stream open");
76        let stream = Self::new(session, channel_id);
77        (TunnelHandle { channel_id }, stream)
78    }
79
80    pub fn channel_id(&self) -> u32 {
81        self.channel_id
82    }
83}
84
85impl Drop for TunnelStream {
86    fn drop(&mut self) {
87        tracing::debug!(
88            channel_id = self.channel_id,
89            write_closed = self.write_closed,
90            read_eof = self.read_eof,
91            "tunnel stream dropped"
92        );
93        // Always unregister locally to avoid leaking an entry in `RpcSession::tunnels`
94        // when the peer stops sending without an EOS.
95        self.session.unregister_tunnel(self.channel_id);
96
97        // Best-effort half-close if the write side wasn't cleanly shut down.
98        // This avoids leaving the peer waiting forever.
99        if !self.write_closed {
100            let session = self.session.clone();
101            let channel_id = self.channel_id;
102            tokio::spawn(async move {
103                let _ = session.close_tunnel(channel_id).await;
104            });
105        }
106    }
107}
108
109impl AsyncRead for TunnelStream {
110    fn poll_read(
111        mut self: Pin<&mut Self>,
112        cx: &mut Context<'_>,
113        buf: &mut ReadBuf<'_>,
114    ) -> Poll<std::io::Result<()>> {
115        if self.read_eof {
116            return Poll::Ready(Ok(()));
117        }
118
119        // Drain buffered bytes first.
120        if !self.read_buf.is_empty() {
121            let to_copy = std::cmp::min(self.read_buf.len(), buf.remaining());
122            buf.put_slice(&self.read_buf.split_to(to_copy));
123
124            if self.read_buf.is_empty() && self.read_eos_after_buf {
125                self.read_eof = true;
126            }
127
128            return Poll::Ready(Ok(()));
129        }
130
131        // Buffer empty: poll for the next chunk.
132        match Pin::new(&mut self.rx).poll_recv(cx) {
133            Poll::Pending => Poll::Pending,
134            Poll::Ready(None) => {
135                self.read_eof = true;
136                if !self.logged_read_eof {
137                    self.logged_read_eof = true;
138                    tracing::debug!(channel_id = self.channel_id, "tunnel read EOF (rx closed)");
139                }
140                Poll::Ready(Ok(()))
141            }
142            Poll::Ready(Some(chunk)) => {
143                if !self.logged_first_read {
144                    self.logged_first_read = true;
145                    tracing::debug!(
146                        channel_id = self.channel_id,
147                        payload_len = chunk.payload_bytes().len(),
148                        is_eos = chunk.is_eos(),
149                        is_error = chunk.is_error(),
150                        "tunnel read first chunk"
151                    );
152                }
153                if chunk.is_error() {
154                    let err = parse_error_payload(chunk.payload_bytes());
155                    let (kind, msg) = match err {
156                        RpcError::Status { code, message } => {
157                            (std::io::ErrorKind::Other, format!("{code:?}: {message}"))
158                        }
159                        RpcError::Transport(e) => {
160                            (std::io::ErrorKind::BrokenPipe, format!("{e:?}"))
161                        }
162                        RpcError::Cancelled => {
163                            (std::io::ErrorKind::Interrupted, "cancelled".into())
164                        }
165                        RpcError::DeadlineExceeded => {
166                            (std::io::ErrorKind::TimedOut, "deadline exceeded".into())
167                        }
168                    };
169                    return Poll::Ready(Err(std::io::Error::new(kind, msg)));
170                }
171
172                let payload = chunk.payload_bytes();
173                if chunk.is_eos() && payload.is_empty() {
174                    self.read_eof = true;
175                    if !self.logged_read_eof {
176                        self.logged_read_eof = true;
177                        tracing::debug!(
178                            channel_id = self.channel_id,
179                            "tunnel read EOF (empty EOS)"
180                        );
181                    }
182                    return Poll::Ready(Ok(()));
183                }
184
185                self.read_buf = Bytes::copy_from_slice(payload);
186                self.read_eos_after_buf = chunk.is_eos();
187
188                // Recurse once to copy into ReadBuf.
189                self.poll_read(cx, buf)
190            }
191        }
192    }
193}
194
195impl AsyncWrite for TunnelStream {
196    fn poll_write(
197        mut self: Pin<&mut Self>,
198        cx: &mut Context<'_>,
199        data: &[u8],
200    ) -> Poll<std::io::Result<usize>> {
201        if self.write_closed {
202            return Poll::Ready(Err(std::io::Error::new(
203                std::io::ErrorKind::BrokenPipe,
204                "tunnel write side closed",
205            )));
206        }
207
208        // Drive any pending send first.
209        if let Some(fut) = self.pending_send.as_mut() {
210            match fut.as_mut().poll(cx) {
211                Poll::Ready(Ok(())) => self.pending_send = None,
212                Poll::Ready(Err(e)) => {
213                    self.pending_send = None;
214                    return Poll::Ready(Err(std::io::Error::new(
215                        std::io::ErrorKind::BrokenPipe,
216                        format!("send failed: {e:?}"),
217                    )));
218                }
219                Poll::Pending => return Poll::Pending,
220            }
221        }
222
223        if data.is_empty() {
224            return Poll::Ready(Ok(0));
225        }
226
227        let channel_id = self.channel_id;
228        if !self.logged_first_write {
229            self.logged_first_write = true;
230            tracing::debug!(channel_id, payload_len = data.len(), "tunnel first write");
231        }
232        let session = self.session.clone();
233        let bytes = data.to_vec();
234        let len = bytes.len();
235        self.pending_send = Some(Box::pin(async move {
236            session.send_chunk(channel_id, bytes).await
237        }));
238
239        // Immediately poll the future once.
240        if let Some(fut) = self.pending_send.as_mut() {
241            match fut.as_mut().poll(cx) {
242                Poll::Ready(Ok(())) => {
243                    self.pending_send = None;
244                    Poll::Ready(Ok(len))
245                }
246                Poll::Ready(Err(e)) => {
247                    self.pending_send = None;
248                    Poll::Ready(Err(std::io::Error::new(
249                        std::io::ErrorKind::BrokenPipe,
250                        format!("send failed: {e:?}"),
251                    )))
252                }
253                Poll::Pending => Poll::Pending,
254            }
255        } else {
256            Poll::Ready(Ok(len))
257        }
258    }
259
260    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
261        if let Some(fut) = self.pending_send.as_mut() {
262            match fut.as_mut().poll(cx) {
263                Poll::Ready(Ok(())) => {
264                    self.pending_send = None;
265                    Poll::Ready(Ok(()))
266                }
267                Poll::Ready(Err(e)) => {
268                    self.pending_send = None;
269                    Poll::Ready(Err(std::io::Error::new(
270                        std::io::ErrorKind::BrokenPipe,
271                        format!("send failed: {e:?}"),
272                    )))
273                }
274                Poll::Pending => Poll::Pending,
275            }
276        } else {
277            Poll::Ready(Ok(()))
278        }
279    }
280
281    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
282        if self.write_closed {
283            return Poll::Ready(Ok(()));
284        }
285
286        match self.as_mut().poll_flush(cx) {
287            Poll::Ready(Ok(())) => {}
288            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
289            Poll::Pending => return Poll::Pending,
290        }
291
292        self.write_closed = true;
293        if !self.logged_shutdown {
294            self.logged_shutdown = true;
295            tracing::debug!(channel_id = self.channel_id, "tunnel shutdown");
296        }
297        let channel_id = self.channel_id;
298        let session = self.session.clone();
299        tokio::spawn(async move {
300            let _ = session.close_tunnel(channel_id).await;
301        });
302        Poll::Ready(Ok(()))
303    }
304}