kube_client/api/portforward.rs

1use std::collections::HashMap;
2
3use bytes::{Buf, Bytes};
4use futures::{
5    FutureExt, SinkExt, StreamExt,
6    channel::{mpsc, oneshot},
7    future,
8};
9use thiserror::Error;
10use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream};
11use tokio_tungstenite::{WebSocketStream, tungstenite as ws};
12use tokio_util::io::ReaderStream;
13
14/// Errors from Portforwarder.
15#[derive(Debug, Error)]
16pub enum Error {
17    /// Received invalid channel in WebSocket message.
18    #[error("received invalid channel {0}")]
19    InvalidChannel(usize),
20
21    /// Received initial frame with invalid size. The initial frame must be 3 bytes, including the channel prefix.
22    #[error("received initial frame with invalid size")]
23    InvalidInitialFrameSize,
24
25    /// Received initial frame with invalid port mapping.
26    /// The port included in the initial frame did not match the port number associated with the channel.
27    #[error("invalid port mapping in initial frame, got {actual}, expected {expected}")]
28    InvalidPortMapping { actual: u16, expected: u16 },
29
30    /// Failed to forward bytes from Pod.
31    #[error("failed to forward bytes from Pod: {0}")]
32    ForwardFromPod(#[source] futures::channel::mpsc::SendError),
33
34    /// Failed to forward bytes to Pod.
35    #[error("failed to forward bytes to Pod: {0}")]
36    ForwardToPod(#[source] futures::channel::mpsc::SendError),
37
38    /// Failed to write bytes from Pod.
39    #[error("failed to write bytes from Pod: {0}")]
40    WriteBytesFromPod(#[source] std::io::Error),
41
42    /// Failed to read bytes to send to Pod.
43    #[error("failed to read bytes to send to Pod: {0}")]
44    ReadBytesToSend(#[source] std::io::Error),
45
46    /// Received an error message from pod that is not a valid UTF-8.
47    #[error("received invalid error message from Pod: {0}")]
48    InvalidErrorMessage(#[source] std::string::FromUtf8Error),
49
50    /// Failed to forward an error message from pod.
51    #[error("failed to forward an error message {0:?}")]
52    ForwardErrorMessage(String),
53
54    /// Failed to send a WebSocket message to the server.
55    #[error("failed to send a WebSocket message: {0}")]
56    SendWebSocketMessage(#[source] ws::Error),
57
58    /// Failed to receive a WebSocket message from the server.
59    #[error("failed to receive a WebSocket message: {0}")]
60    ReceiveWebSocketMessage(#[source] ws::Error),
61
62    #[error("failed to complete the background task: {0}")]
63    Spawn(#[source] tokio::task::JoinError),
64
65    /// Failed to shutdown a pod writer channel.
66    #[error("failed to shutdown write to Pod channel: {0}")]
67    Shutdown(#[source] std::io::Error),
68}
69
70type ErrorReceiver = oneshot::Receiver<String>;
71type ErrorSender = oneshot::Sender<String>;
72
73// Internal message used by the futures to communicate with each other.
74enum Message {
75    FromPod(u8, Bytes),
76    ToPod(u8, Bytes),
77    FromPodClose,
78    ToPodClose(u8),
79}
80
81/// Manages port-forwarded streams.
82///
83/// Provides `AsyncRead + AsyncWrite` for each port and **does not** bind to local ports.  Error
84/// channel for each port is only written by the server when there's an exception and
85/// the port cannot be used (didn't initialize or can't be used anymore).
86pub struct Portforwarder {
87    ports: HashMap<u16, DuplexStream>,
88    errors: HashMap<u16, ErrorReceiver>,
89    task: tokio::task::JoinHandle<Result<(), Error>>,
90}
91
92impl Portforwarder {
93    pub(crate) fn new<S>(stream: WebSocketStream<S>, port_nums: &[u16]) -> Self
94    where
95        S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
96    {
97        let mut ports = HashMap::with_capacity(port_nums.len());
98        let mut error_rxs = HashMap::with_capacity(port_nums.len());
99        let mut error_txs = Vec::with_capacity(port_nums.len());
100        let mut task_ios = Vec::with_capacity(port_nums.len());
101        for port in port_nums.iter() {
102            let (a, b) = tokio::io::duplex(1024 * 1024);
103            ports.insert(*port, a);
104            task_ios.push(b);
105
106            let (tx, rx) = oneshot::channel();
107            error_rxs.insert(*port, rx);
108            error_txs.push(Some(tx));
109        }
110        let task = tokio::spawn(start_message_loop(
111            stream,
112            port_nums.to_vec(),
113            task_ios,
114            error_txs,
115        ));
116
117        Portforwarder {
118            ports,
119            errors: error_rxs,
120            task,
121        }
122    }
123
124    /// Take a port stream by the port on the target resource.
125    ///
126    /// A value is returned at most once per port.
127    #[inline]
128    pub fn take_stream(&mut self, port: u16) -> Option<impl AsyncRead + AsyncWrite + Unpin + use<>> {
129        self.ports.remove(&port)
130    }
131
132    /// Take a future that resolves with any error message or when the error sender is dropped.
133    /// When the future resolves, the port should be considered no longer usable.
134    ///
135    /// A value is returned at most once per port.
136    #[inline]
137    pub fn take_error(&mut self, port: u16) -> Option<impl Future<Output = Option<String>> + use<>> {
138        self.errors.remove(&port).map(|recv| recv.map(|res| res.ok()))
139    }
140
141    /// Abort the background task, causing port forwards to fail.
142    #[inline]
143    pub fn abort(&self) {
144        self.task.abort();
145    }
146
147    /// Waits for port forwarding task to complete.
148    pub async fn join(self) -> Result<(), Error> {
149        let Self {
150            mut ports,
151            mut errors,
152            task,
153        } = self;
154        // Start by terminating any streams that have not yet been taken
155        // since they would otherwise keep the connection open indefinitely
156        ports.clear();
157        errors.clear();
158        task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
159    }
160}
161
162async fn start_message_loop<S>(
163    stream: WebSocketStream<S>,
164    ports: Vec<u16>,
165    duplexes: Vec<DuplexStream>,
166    error_senders: Vec<Option<ErrorSender>>,
167) -> Result<(), Error>
168where
169    S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
170{
171    let mut writers = Vec::new();
172    // Loops to run concurrently.
173    // We can spawn tasks to run `to_pod_loop` in parallel and flatten the errors, but the other 2 loops
174    // are over a single WebSocket connection and cannot process each port in parallel.
175    let mut loops = Vec::with_capacity(ports.len() + 2);
176    // Channel to communicate with the main loop
177    let (sender, receiver) = mpsc::channel::<Message>(1);
178    for (i, (r, w)) in duplexes.into_iter().map(tokio::io::split).enumerate() {
179        writers.push(w);
180        // Each port uses 2 channels. Duplex data channel and error.
181        let ch = 2 * (i as u8);
182        loops.push(to_pod_loop(ch, r, sender.clone()).boxed());
183    }
184
185    let (ws_sink, ws_stream) = stream.split();
186    loops.push(from_pod_loop(ws_stream, sender).boxed());
187    loops.push(forwarder_loop(&ports, receiver, ws_sink, writers, error_senders).boxed());
188
189    future::try_join_all(loops).await.map(|_| ())
190}
191
192async fn to_pod_loop(
193    ch: u8,
194    reader: tokio::io::ReadHalf<DuplexStream>,
195    mut sender: mpsc::Sender<Message>,
196) -> Result<(), Error> {
197    let mut read_stream = ReaderStream::new(reader);
198    while let Some(bytes) = read_stream
199        .next()
200        .await
201        .transpose()
202        .map_err(Error::ReadBytesToSend)?
203    {
204        if !bytes.is_empty() {
205            sender
206                .send(Message::ToPod(ch, bytes))
207                .await
208                .map_err(Error::ForwardToPod)?;
209        }
210    }
211    sender
212        .send(Message::ToPodClose(ch))
213        .await
214        .map_err(Error::ForwardToPod)?;
215    Ok(())
216}
217
218async fn from_pod_loop<S>(
219    mut ws_stream: futures::stream::SplitStream<WebSocketStream<S>>,
220    mut sender: mpsc::Sender<Message>,
221) -> Result<(), Error>
222where
223    S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
224{
225    while let Some(msg) = ws_stream
226        .next()
227        .await
228        .transpose()
229        .map_err(Error::ReceiveWebSocketMessage)?
230    {
231        match msg {
232            ws::Message::Binary(mut bytes) if bytes.len() > 1 => {
233                let ch = bytes.split_to(1)[0];
234                sender
235                    .send(Message::FromPod(ch, bytes))
236                    .await
237                    .map_err(Error::ForwardFromPod)?;
238            }
239            message if message.is_close() => {
240                sender
241                    .send(Message::FromPodClose)
242                    .await
243                    .map_err(Error::ForwardFromPod)?;
244                break;
245            }
246            // REVIEW should we error on unexpected websocket message?
247            _ => {}
248        }
249    }
250    Ok(())
251}
252
253// Start a loop to handle messages received from other futures.
254// On `Message::ToPod(ch, bytes)`, a WebSocket message is sent with the channel prefix.
255// On `Message::FromPod(ch, bytes)` with an even `ch`, `bytes` are written to the port's sink.
256// On `Message::FromPod(ch, bytes)` with an odd `ch`, an error message is sent to the error channel of the port.
257async fn forwarder_loop<S>(
258    ports: &[u16],
259    mut receiver: mpsc::Receiver<Message>,
260    mut ws_sink: futures::stream::SplitSink<WebSocketStream<S>, ws::Message>,
261    mut writers: Vec<tokio::io::WriteHalf<DuplexStream>>,
262    mut error_senders: Vec<Option<ErrorSender>>,
263) -> Result<(), Error>
264where
265    S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
266{
267    #[derive(Default, Clone)]
268    struct ChannelState {
269        // Keep track if the channel has received the initialization frame.
270        initialized: bool,
271        // Keep track if the channel has shutdown.
272        shutdown: bool,
273    }
274    let mut chan_state = vec![ChannelState::default(); 2 * ports.len()];
275    let mut closed_ports = 0;
276    let mut socket_shutdown = false;
277    while let Some(msg) = receiver.next().await {
278        match msg {
279            Message::FromPod(ch, mut bytes) => {
280                let ch = ch as usize;
281                let channel = chan_state.get_mut(ch).ok_or_else(|| Error::InvalidChannel(ch))?;
282
283                let port_index = ch / 2;
284                // Initialization
285                if !channel.initialized {
286                    // The initial message must be 3 bytes including the channel prefix.
287                    if bytes.len() != 2 {
288                        return Err(Error::InvalidInitialFrameSize);
289                    }
290
291                    let port = bytes.get_u16_le();
292                    if port != ports[port_index] {
293                        return Err(Error::InvalidPortMapping {
294                            actual: port,
295                            expected: ports[port_index],
296                        });
297                    }
298
299                    channel.initialized = true;
300                    continue;
301                }
302
303                // Odd channels are for errors for (n - 1)/2 th port
304                if !ch.is_multiple_of(2) {
305                    // A port sends at most one error message because it's considered unusable after this.
306                    if let Some(sender) = error_senders[port_index].take() {
307                        let s = String::from_utf8(bytes.into_iter().collect())
308                            .map_err(Error::InvalidErrorMessage)?;
309                        sender.send(s).map_err(Error::ForwardErrorMessage)?;
310                    }
311                } else if !channel.shutdown {
312                    writers[port_index]
313                        .write_all(&bytes)
314                        .await
315                        .map_err(Error::WriteBytesFromPod)?;
316                }
317            }
318
319            Message::ToPod(ch, bytes) => {
320                let mut bin = Vec::with_capacity(bytes.len() + 1);
321                bin.push(ch);
322                bin.extend(bytes);
323                ws_sink
324                    .send(ws::Message::binary(bin))
325                    .await
326                    .map_err(Error::SendWebSocketMessage)?;
327            }
328            Message::ToPodClose(ch) => {
329                let ch = ch as usize;
330                let channel = chan_state.get_mut(ch).ok_or_else(|| Error::InvalidChannel(ch))?;
331                let port_index = ch / 2;
332
333                if !channel.shutdown {
334                    writers[port_index].shutdown().await.map_err(Error::Shutdown)?;
335                    channel.shutdown = true;
336
337                    closed_ports += 1;
338                }
339            }
340            Message::FromPodClose => {
341                for writer in &mut writers {
342                    writer.shutdown().await.map_err(Error::Shutdown)?;
343                }
344            }
345        }
346
347        if closed_ports == ports.len() && !socket_shutdown {
348            ws_sink
349                .send(ws::Message::Close(None))
350                .await
351                .map_err(Error::SendWebSocketMessage)?;
352            socket_shutdown = true;
353        }
354    }
355    Ok(())
356}