1use std::collections::HashMap;
2
3use bytes::{Buf, Bytes};
4use futures::{
5 FutureExt, SinkExt, StreamExt,
6 channel::{mpsc, oneshot},
7 future,
8};
9use thiserror::Error;
10use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream};
11use tokio_tungstenite::{WebSocketStream, tungstenite as ws};
12use tokio_util::io::ReaderStream;
13
14#[derive(Debug, Error)]
16pub enum Error {
17 #[error("received invalid channel {0}")]
19 InvalidChannel(usize),
20
21 #[error("received initial frame with invalid size")]
23 InvalidInitialFrameSize,
24
25 #[error("invalid port mapping in initial frame, got {actual}, expected {expected}")]
28 InvalidPortMapping { actual: u16, expected: u16 },
29
30 #[error("failed to forward bytes from Pod: {0}")]
32 ForwardFromPod(#[source] futures::channel::mpsc::SendError),
33
34 #[error("failed to forward bytes to Pod: {0}")]
36 ForwardToPod(#[source] futures::channel::mpsc::SendError),
37
38 #[error("failed to write bytes from Pod: {0}")]
40 WriteBytesFromPod(#[source] std::io::Error),
41
42 #[error("failed to read bytes to send to Pod: {0}")]
44 ReadBytesToSend(#[source] std::io::Error),
45
46 #[error("received invalid error message from Pod: {0}")]
48 InvalidErrorMessage(#[source] std::string::FromUtf8Error),
49
50 #[error("failed to forward an error message {0:?}")]
52 ForwardErrorMessage(String),
53
54 #[error("failed to send a WebSocket message: {0}")]
56 SendWebSocketMessage(#[source] ws::Error),
57
58 #[error("failed to receive a WebSocket message: {0}")]
60 ReceiveWebSocketMessage(#[source] ws::Error),
61
62 #[error("failed to complete the background task: {0}")]
63 Spawn(#[source] tokio::task::JoinError),
64
65 #[error("failed to shutdown write to Pod channel: {0}")]
67 Shutdown(#[source] std::io::Error),
68}
69
70type ErrorReceiver = oneshot::Receiver<String>;
71type ErrorSender = oneshot::Sender<String>;
72
73enum Message {
75 FromPod(u8, Bytes),
76 ToPod(u8, Bytes),
77 FromPodClose,
78 ToPodClose(u8),
79}
80
81pub struct Portforwarder {
87 ports: HashMap<u16, DuplexStream>,
88 errors: HashMap<u16, ErrorReceiver>,
89 task: tokio::task::JoinHandle<Result<(), Error>>,
90}
91
92impl Portforwarder {
93 pub(crate) fn new<S>(stream: WebSocketStream<S>, port_nums: &[u16]) -> Self
94 where
95 S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
96 {
97 let mut ports = HashMap::with_capacity(port_nums.len());
98 let mut error_rxs = HashMap::with_capacity(port_nums.len());
99 let mut error_txs = Vec::with_capacity(port_nums.len());
100 let mut task_ios = Vec::with_capacity(port_nums.len());
101 for port in port_nums.iter() {
102 let (a, b) = tokio::io::duplex(1024 * 1024);
103 ports.insert(*port, a);
104 task_ios.push(b);
105
106 let (tx, rx) = oneshot::channel();
107 error_rxs.insert(*port, rx);
108 error_txs.push(Some(tx));
109 }
110 let task = tokio::spawn(start_message_loop(
111 stream,
112 port_nums.to_vec(),
113 task_ios,
114 error_txs,
115 ));
116
117 Portforwarder {
118 ports,
119 errors: error_rxs,
120 task,
121 }
122 }
123
124 #[inline]
128 pub fn take_stream(&mut self, port: u16) -> Option<impl AsyncRead + AsyncWrite + Unpin + use<>> {
129 self.ports.remove(&port)
130 }
131
132 #[inline]
137 pub fn take_error(&mut self, port: u16) -> Option<impl Future<Output = Option<String>> + use<>> {
138 self.errors.remove(&port).map(|recv| recv.map(|res| res.ok()))
139 }
140
141 #[inline]
143 pub fn abort(&self) {
144 self.task.abort();
145 }
146
147 pub async fn join(self) -> Result<(), Error> {
149 let Self {
150 mut ports,
151 mut errors,
152 task,
153 } = self;
154 ports.clear();
157 errors.clear();
158 task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
159 }
160}
161
162async fn start_message_loop<S>(
163 stream: WebSocketStream<S>,
164 ports: Vec<u16>,
165 duplexes: Vec<DuplexStream>,
166 error_senders: Vec<Option<ErrorSender>>,
167) -> Result<(), Error>
168where
169 S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
170{
171 let mut writers = Vec::new();
172 let mut loops = Vec::with_capacity(ports.len() + 2);
176 let (sender, receiver) = mpsc::channel::<Message>(1);
178 for (i, (r, w)) in duplexes.into_iter().map(tokio::io::split).enumerate() {
179 writers.push(w);
180 let ch = 2 * (i as u8);
182 loops.push(to_pod_loop(ch, r, sender.clone()).boxed());
183 }
184
185 let (ws_sink, ws_stream) = stream.split();
186 loops.push(from_pod_loop(ws_stream, sender).boxed());
187 loops.push(forwarder_loop(&ports, receiver, ws_sink, writers, error_senders).boxed());
188
189 future::try_join_all(loops).await.map(|_| ())
190}
191
192async fn to_pod_loop(
193 ch: u8,
194 reader: tokio::io::ReadHalf<DuplexStream>,
195 mut sender: mpsc::Sender<Message>,
196) -> Result<(), Error> {
197 let mut read_stream = ReaderStream::new(reader);
198 while let Some(bytes) = read_stream
199 .next()
200 .await
201 .transpose()
202 .map_err(Error::ReadBytesToSend)?
203 {
204 if !bytes.is_empty() {
205 sender
206 .send(Message::ToPod(ch, bytes))
207 .await
208 .map_err(Error::ForwardToPod)?;
209 }
210 }
211 sender
212 .send(Message::ToPodClose(ch))
213 .await
214 .map_err(Error::ForwardToPod)?;
215 Ok(())
216}
217
218async fn from_pod_loop<S>(
219 mut ws_stream: futures::stream::SplitStream<WebSocketStream<S>>,
220 mut sender: mpsc::Sender<Message>,
221) -> Result<(), Error>
222where
223 S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
224{
225 while let Some(msg) = ws_stream
226 .next()
227 .await
228 .transpose()
229 .map_err(Error::ReceiveWebSocketMessage)?
230 {
231 match msg {
232 ws::Message::Binary(mut bytes) if bytes.len() > 1 => {
233 let ch = bytes.split_to(1)[0];
234 sender
235 .send(Message::FromPod(ch, bytes))
236 .await
237 .map_err(Error::ForwardFromPod)?;
238 }
239 message if message.is_close() => {
240 sender
241 .send(Message::FromPodClose)
242 .await
243 .map_err(Error::ForwardFromPod)?;
244 break;
245 }
246 _ => {}
248 }
249 }
250 Ok(())
251}
252
253async fn forwarder_loop<S>(
258 ports: &[u16],
259 mut receiver: mpsc::Receiver<Message>,
260 mut ws_sink: futures::stream::SplitSink<WebSocketStream<S>, ws::Message>,
261 mut writers: Vec<tokio::io::WriteHalf<DuplexStream>>,
262 mut error_senders: Vec<Option<ErrorSender>>,
263) -> Result<(), Error>
264where
265 S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
266{
267 #[derive(Default, Clone)]
268 struct ChannelState {
269 initialized: bool,
271 shutdown: bool,
273 }
274 let mut chan_state = vec![ChannelState::default(); 2 * ports.len()];
275 let mut closed_ports = 0;
276 let mut socket_shutdown = false;
277 while let Some(msg) = receiver.next().await {
278 match msg {
279 Message::FromPod(ch, mut bytes) => {
280 let ch = ch as usize;
281 let channel = chan_state.get_mut(ch).ok_or_else(|| Error::InvalidChannel(ch))?;
282
283 let port_index = ch / 2;
284 if !channel.initialized {
286 if bytes.len() != 2 {
288 return Err(Error::InvalidInitialFrameSize);
289 }
290
291 let port = bytes.get_u16_le();
292 if port != ports[port_index] {
293 return Err(Error::InvalidPortMapping {
294 actual: port,
295 expected: ports[port_index],
296 });
297 }
298
299 channel.initialized = true;
300 continue;
301 }
302
303 if !ch.is_multiple_of(2) {
305 if let Some(sender) = error_senders[port_index].take() {
307 let s = String::from_utf8(bytes.into_iter().collect())
308 .map_err(Error::InvalidErrorMessage)?;
309 sender.send(s).map_err(Error::ForwardErrorMessage)?;
310 }
311 } else if !channel.shutdown {
312 writers[port_index]
313 .write_all(&bytes)
314 .await
315 .map_err(Error::WriteBytesFromPod)?;
316 }
317 }
318
319 Message::ToPod(ch, bytes) => {
320 let mut bin = Vec::with_capacity(bytes.len() + 1);
321 bin.push(ch);
322 bin.extend(bytes);
323 ws_sink
324 .send(ws::Message::binary(bin))
325 .await
326 .map_err(Error::SendWebSocketMessage)?;
327 }
328 Message::ToPodClose(ch) => {
329 let ch = ch as usize;
330 let channel = chan_state.get_mut(ch).ok_or_else(|| Error::InvalidChannel(ch))?;
331 let port_index = ch / 2;
332
333 if !channel.shutdown {
334 writers[port_index].shutdown().await.map_err(Error::Shutdown)?;
335 channel.shutdown = true;
336
337 closed_ports += 1;
338 }
339 }
340 Message::FromPodClose => {
341 for writer in &mut writers {
342 writer.shutdown().await.map_err(Error::Shutdown)?;
343 }
344 }
345 }
346
347 if closed_ports == ports.len() && !socket_shutdown {
348 ws_sink
349 .send(ws::Message::Close(None))
350 .await
351 .map_err(Error::SendWebSocketMessage)?;
352 socket_shutdown = true;
353 }
354 }
355 Ok(())
356}