1use std::time::Duration;
2
3use bytes::Bytes;
4use k8s_openapi::apimachinery::pkg::apis::meta::v1::Status;
5
6use futures::{
7 FutureExt, SinkExt, StreamExt,
8 channel::{mpsc, oneshot},
9};
10use serde::{Deserialize, Serialize};
11use thiserror::Error;
12use tokio::{
13 io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream},
14 select, time,
15};
16use tokio_tungstenite::tungstenite as ws;
17
18use crate::client::Connection;
19
20use super::AttachParams;
21
/// Receiving half of the one-shot channel that carries the `Status` object read from the
/// status channel out of the background message loop.
type StatusReceiver = oneshot::Receiver<Status>;
/// Sending half of the one-shot status channel; owned by the background message loop.
type StatusSender = oneshot::Sender<Status>;

/// Receiving half of the channel through which terminal resize requests reach the message loop.
type TerminalSizeReceiver = mpsc::Receiver<TerminalSize>;
/// Sending half of the terminal resize channel; handed out by `AttachedProcess::terminal_size`.
type TerminalSizeSender = mpsc::Sender<TerminalSize>;
27
/// Terminal dimensions forwarded to the remote process when the terminal is resized.
///
/// The message loop serializes this value as JSON and sends it on the resize channel.
/// Field names are emitted in `PascalCase` (`Width`, `Height`).
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
#[serde(rename_all = "PascalCase")]
pub struct TerminalSize {
    /// Width of the terminal.
    pub width: u16,
    /// Height of the terminal.
    pub height: u16,
}
38
39#[derive(Debug, Error)]
41pub enum Error {
42 #[error("failed to read from stdin: {0}")]
44 ReadStdin(#[source] std::io::Error),
45
46 #[error("failed to send a stdin data: {0}")]
48 SendStdin(#[source] ws::Error),
49
50 #[error("failed to write to stdout: {0}")]
52 WriteStdout(#[source] std::io::Error),
53
54 #[error("failed to write to stderr: {0}")]
56 WriteStderr(#[source] std::io::Error),
57
58 #[error("failed to receive a WebSocket message: {0}")]
60 ReceiveWebSocketMessage(#[source] ws::Error),
61
62 #[error("failed to complete the background task: {0}")]
64 Spawn(#[source] tokio::task::JoinError),
65
66 #[error("failed to send a WebSocket close message: {0}")]
68 SendClose(#[source] ws::Error),
69
70 #[error("failed to send a WebSocket ping message: {0}")]
72 SendPing(#[source] ws::Error),
73
74 #[error("failed to deserialize status object: {0}")]
76 DeserializeStatus(#[source] serde_json::Error),
77
78 #[error("failed to send status object")]
80 SendStatus,
81
82 #[error("failed to serialize TerminalSize object: {0}")]
84 SerializeTerminalSize(#[source] serde_json::Error),
85
86 #[error("failed to send terminal size message")]
88 SendTerminalSize(#[source] ws::Error),
89
90 #[error("failed to set terminal size, tty need to be true to resize the terminal")]
92 TtyNeedToBeTrue,
93}
94
/// Default buffer size passed to `tokio::io::duplex` for the stdin/stdout/stderr pipes when
/// `AttachParams` does not specify one.
const MAX_BUF_SIZE: usize = 1024;
96
/// Handle to a process attached over a WebSocket connection.
///
/// Constructing it spawns a background task that runs the message loop. The loop moves data
/// between the connection and the in-memory pipes and channels whose user-facing ends are
/// stored here until they are taken.
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct AttachedProcess {
    /// Whether stdin was requested in the `AttachParams`.
    has_stdin: bool,
    /// Whether stdout was requested in the `AttachParams`.
    has_stdout: bool,
    /// Whether stderr was requested in the `AttachParams`.
    has_stderr: bool,
    /// Writer end of the stdin pipe; `None` once taken via `stdin()`.
    stdin_writer: Option<DuplexStream>,
    /// Reader end of the stdout pipe; `None` if not requested or already taken.
    stdout_reader: Option<DuplexStream>,
    /// Reader end of the stderr pipe; `None` if not requested or already taken.
    stderr_reader: Option<DuplexStream>,
    /// Receiver for the status object; `None` once taken via `take_status()`.
    status_rx: Option<StatusReceiver>,
    /// Sender for terminal resize requests; `None` without a TTY or once taken.
    terminal_resize_tx: Option<TerminalSizeSender>,
    /// Join handle of the background task running the message loop.
    task: tokio::task::JoinHandle<Result<(), Error>>,
}
117
118impl AttachedProcess {
119 pub(crate) fn new(connection: Connection, ap: &AttachParams) -> Self {
120 let (stdin_writer, stdin_reader) = tokio::io::duplex(ap.max_stdin_buf_size.unwrap_or(MAX_BUF_SIZE));
123 let (stdout_writer, stdout_reader) = if ap.stdout {
124 let (w, r) = tokio::io::duplex(ap.max_stdout_buf_size.unwrap_or(MAX_BUF_SIZE));
125 (Some(w), Some(r))
126 } else {
127 (None, None)
128 };
129 let (stderr_writer, stderr_reader) = if ap.stderr {
130 let (w, r) = tokio::io::duplex(ap.max_stderr_buf_size.unwrap_or(MAX_BUF_SIZE));
131 (Some(w), Some(r))
132 } else {
133 (None, None)
134 };
135 let (status_tx, status_rx) = oneshot::channel();
136 let (terminal_resize_tx, terminal_resize_rx) = if ap.tty {
137 let (w, r) = mpsc::channel(10);
138 (Some(w), Some(r))
139 } else {
140 (None, None)
141 };
142
143 let task = tokio::spawn(start_message_loop(
144 connection,
145 stdin_reader,
146 stdout_writer,
147 stderr_writer,
148 status_tx,
149 terminal_resize_rx,
150 ));
151
152 AttachedProcess {
153 has_stdin: ap.stdin,
154 has_stdout: ap.stdout,
155 has_stderr: ap.stderr,
156 task,
157 stdin_writer: Some(stdin_writer),
158 stdout_reader,
159 stderr_reader,
160 terminal_resize_tx,
161 status_rx: Some(status_rx),
162 }
163 }
164
165 pub fn stdin(&mut self) -> Option<impl AsyncWrite + Unpin + use<>> {
178 if !self.has_stdin {
179 return None;
180 }
181 self.stdin_writer.take()
182 }
183
184 pub fn stdout(&mut self) -> Option<impl AsyncRead + Unpin + use<>> {
198 if !self.has_stdout {
199 return None;
200 }
201 self.stdout_reader.take()
202 }
203
204 pub fn stderr(&mut self) -> Option<impl AsyncRead + Unpin + use<>> {
218 if !self.has_stderr {
219 return None;
220 }
221 self.stderr_reader.take()
222 }
223
224 #[inline]
226 pub fn abort(&self) {
227 self.task.abort();
228 }
229
230 pub async fn join(self) -> Result<(), Error> {
232 self.task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
233 }
234
235 pub fn take_status(&mut self) -> Option<impl Future<Output = Option<Status>> + use<>> {
239 self.status_rx.take().map(|recv| recv.map(|res| res.ok()))
240 }
241
242 pub fn terminal_size(&mut self) -> Option<TerminalSizeSender> {
259 self.terminal_resize_tx.take()
260 }
261}
262
// Every binary frame starts with one byte that identifies the channel it belongs to.

/// Channel byte prefixed to stdin data sent to the server.
const STDIN_CHANNEL: u8 = 0;
/// Channel byte of frames carrying the process's stdout.
const STDOUT_CHANNEL: u8 = 1;
/// Channel byte of frames carrying the process's stderr.
const STDERR_CHANNEL: u8 = 2;
/// Channel byte of frames carrying the JSON-encoded `Status` object.
const STATUS_CHANNEL: u8 = 3;
/// Channel byte prefixed to JSON-encoded `TerminalSize` messages sent to the server.
const RESIZE_CHANNEL: u8 = 4;
/// Channel byte of the close signal; it is followed by the id of the channel being closed.
/// Only sent when the connection reports that it supports stream close.
const CLOSE_CHANNEL: u8 = 255;
273
274async fn start_message_loop(
275 connection: Connection,
276 stdin: impl AsyncRead + Unpin,
277 mut stdout: Option<impl AsyncWrite + Unpin>,
278 mut stderr: Option<impl AsyncWrite + Unpin>,
279 status_tx: StatusSender,
280 mut terminal_size_rx: Option<TerminalSizeReceiver>,
281) -> Result<(), Error> {
282 let supports_stream_close = connection.supports_stream_close();
283 let stream = connection.into_stream();
284 let mut stdin_stream = tokio_util::io::ReaderStream::new(stdin);
285 let (mut server_send, raw_server_recv) = stream.split();
286 let mut server_recv = raw_server_recv.filter_map(filter_message).boxed();
288 let mut have_terminal_size_rx = terminal_size_rx.is_some();
289
290 let mut stdin_is_open = true;
292
293 let mut ping_interval = time::interval(Duration::from_secs(60));
294 ping_interval.set_missed_tick_behavior(time::MissedTickBehavior::Delay);
295 ping_interval.reset();
296
297 loop {
298 let terminal_size_next = async {
299 match terminal_size_rx.as_mut() {
300 Some(tmp) => Some(tmp.next().await),
301 None => None,
302 }
303 };
304
305 select! {
306 _ = ping_interval.tick() => {
307 server_send
309 .send(ws::Message::Ping(Bytes::new()))
310 .await
311 .map_err(Error::SendPing)?;
312 },
313
314 server_message = server_recv.next() => {
315 match server_message {
316 Some(Ok(Message::Stdout(bin))) => {
317 if let Some(stdout) = stdout.as_mut() {
318 stdout.write_all(&bin[1..]).await.map_err(Error::WriteStdout)?;
319 }
320 },
321 Some(Ok(Message::Stderr(bin))) => {
322 if let Some(stderr) = stderr.as_mut() {
323 stderr.write_all(&bin[1..]).await.map_err(Error::WriteStderr)?;
324 }
325 },
326 Some(Ok(Message::Status(bin))) => {
327 let status = serde_json::from_slice::<Status>(&bin[1..]).map_err(Error::DeserializeStatus)?;
328 status_tx.send(status).map_err(|_| Error::SendStatus)?;
329 break
330 },
331 Some(Err(err)) => {
332 return Err(Error::ReceiveWebSocketMessage(err));
333 },
334 None => {
335 break
337 },
338 }
339 },
340 stdin_message = stdin_stream.next(), if stdin_is_open => {
341 match stdin_message {
342 Some(Ok(bytes)) => {
343 if !bytes.is_empty() {
344 let mut vec = Vec::with_capacity(bytes.len() + 1);
345 vec.push(STDIN_CHANNEL);
346 vec.extend_from_slice(&bytes[..]);
347 server_send
348 .send(ws::Message::binary(vec))
349 .await
350 .map_err(Error::SendStdin)?;
351 }
352 },
353 Some(Err(err)) => {
354 return Err(Error::ReadStdin(err));
355 }
356 None => {
357 if supports_stream_close {
360 let vec = vec![CLOSE_CHANNEL, STDIN_CHANNEL];
363 server_send
364 .send(ws::Message::binary(vec))
365 .await
366 .map_err(Error::SendStdin)?;
367 } else {
368 server_send.close().await.map_err(Error::SendClose)?;
372 }
373
374 stdin_is_open = false;
376 }
377 }
378 },
379 Some(terminal_size_message) = terminal_size_next, if have_terminal_size_rx => {
380 match terminal_size_message {
381 Some(new_size) => {
382 let new_size = serde_json::to_vec(&new_size).map_err(Error::SerializeTerminalSize)?;
383 let mut vec = Vec::with_capacity(new_size.len() + 1);
384 vec.push(RESIZE_CHANNEL);
385 vec.extend_from_slice(&new_size[..]);
386 server_send.send(ws::Message::Binary(vec.into())).await.map_err(Error::SendTerminalSize)?;
387 },
388 None => {
389 have_terminal_size_rx = false;
390 }
391 }
392 },
393 }
394 }
395
396 Ok(())
397}
398
/// A binary frame received from the server, classified by its channel byte.
///
/// Each variant holds the complete frame, including the leading channel byte; the message loop
/// skips that byte when it consumes the payload.
enum Message {
    /// Frame received on the stdout channel.
    Stdout(Vec<u8>),
    /// Frame received on the stderr channel.
    Stderr(Vec<u8>),
    /// Frame received on the status channel (JSON-encoded `Status`).
    Status(Vec<u8>),
}
408
409async fn filter_message(wsm: Result<ws::Message, ws::Error>) -> Option<Result<Message, ws::Error>> {
411 match wsm {
412 Ok(ws::Message::Binary(bin)) if bin.len() > 1 => match bin[0] {
415 STDOUT_CHANNEL => Some(Ok(Message::Stdout(bin.into()))),
416 STDERR_CHANNEL => Some(Ok(Message::Stderr(bin.into()))),
417 STATUS_CHANNEL => Some(Ok(Message::Status(bin.into()))),
418 _ => None,
420 },
421 Ok(_) => None,
425 Err(err) => Some(Err(err)),
428 }
429}