1use k8s_openapi::apimachinery::pkg::apis::meta::v1::Status;
2
3use futures::{
4 channel::{mpsc, oneshot},
5 FutureExt, SinkExt, StreamExt,
6};
7use serde::{Deserialize, Serialize};
8use thiserror::Error;
9use tokio::{
10 io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream},
11 select,
12};
13use tokio_tungstenite::tungstenite as ws;
14
15use crate::client::Connection;
16
17use super::AttachParams;
18
19type StatusReceiver = oneshot::Receiver<Status>;
20type StatusSender = oneshot::Sender<Status>;
21
22type TerminalSizeReceiver = mpsc::Receiver<TerminalSize>;
23type TerminalSizeSender = mpsc::Sender<TerminalSize>;
24
25#[derive(Debug, Serialize, Deserialize)]
27#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
28#[serde(rename_all = "PascalCase")]
29pub struct TerminalSize {
30 pub width: u16,
32 pub height: u16,
34}
35
36#[derive(Debug, Error)]
38pub enum Error {
39 #[error("failed to read from stdin: {0}")]
41 ReadStdin(#[source] std::io::Error),
42
43 #[error("failed to send a stdin data: {0}")]
45 SendStdin(#[source] ws::Error),
46
47 #[error("failed to write to stdout: {0}")]
49 WriteStdout(#[source] std::io::Error),
50
51 #[error("failed to write to stderr: {0}")]
53 WriteStderr(#[source] std::io::Error),
54
55 #[error("failed to receive a WebSocket message: {0}")]
57 ReceiveWebSocketMessage(#[source] ws::Error),
58
59 #[error("failed to complete the background task: {0}")]
61 Spawn(#[source] tokio::task::JoinError),
62
63 #[error("failed to send a WebSocket close message: {0}")]
65 SendClose(#[source] ws::Error),
66
67 #[error("failed to deserialize status object: {0}")]
69 DeserializeStatus(#[source] serde_json::Error),
70
71 #[error("failed to send status object")]
73 SendStatus,
74
75 #[error("failed to serialize TerminalSize object: {0}")]
77 SerializeTerminalSize(#[source] serde_json::Error),
78
79 #[error("failed to send terminal size message")]
81 SendTerminalSize(#[source] ws::Error),
82
83 #[error("failed to set terminal size, tty need to be true to resize the terminal")]
85 TtyNeedToBeTrue,
86}
87
/// Capacity handed to `tokio::io::duplex` for a stream's in-memory pipe when
/// `AttachParams` leaves the matching `max_*_buf_size` unset.
const MAX_BUF_SIZE: usize = 1024;
89
90#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
99pub struct AttachedProcess {
100 has_stdin: bool,
101 has_stdout: bool,
102 has_stderr: bool,
103 stdin_writer: Option<DuplexStream>,
104 stdout_reader: Option<DuplexStream>,
105 stderr_reader: Option<DuplexStream>,
106 status_rx: Option<StatusReceiver>,
107 terminal_resize_tx: Option<TerminalSizeSender>,
108 task: tokio::task::JoinHandle<Result<(), Error>>,
109}
110
111impl AttachedProcess {
112 pub(crate) fn new(connection: Connection, ap: &AttachParams) -> Self {
113 let (stdin_writer, stdin_reader) = tokio::io::duplex(ap.max_stdin_buf_size.unwrap_or(MAX_BUF_SIZE));
116 let (stdout_writer, stdout_reader) = if ap.stdout {
117 let (w, r) = tokio::io::duplex(ap.max_stdout_buf_size.unwrap_or(MAX_BUF_SIZE));
118 (Some(w), Some(r))
119 } else {
120 (None, None)
121 };
122 let (stderr_writer, stderr_reader) = if ap.stderr {
123 let (w, r) = tokio::io::duplex(ap.max_stderr_buf_size.unwrap_or(MAX_BUF_SIZE));
124 (Some(w), Some(r))
125 } else {
126 (None, None)
127 };
128 let (status_tx, status_rx) = oneshot::channel();
129 let (terminal_resize_tx, terminal_resize_rx) = if ap.tty {
130 let (w, r) = mpsc::channel(10);
131 (Some(w), Some(r))
132 } else {
133 (None, None)
134 };
135
136 let task = tokio::spawn(start_message_loop(
137 connection,
138 stdin_reader,
139 stdout_writer,
140 stderr_writer,
141 status_tx,
142 terminal_resize_rx,
143 ));
144
145 AttachedProcess {
146 has_stdin: ap.stdin,
147 has_stdout: ap.stdout,
148 has_stderr: ap.stderr,
149 task,
150 stdin_writer: Some(stdin_writer),
151 stdout_reader,
152 stderr_reader,
153 terminal_resize_tx,
154 status_rx: Some(status_rx),
155 }
156 }
157
158 pub fn stdin(&mut self) -> Option<impl AsyncWrite + Unpin + use<>> {
171 if !self.has_stdin {
172 return None;
173 }
174 self.stdin_writer.take()
175 }
176
177 pub fn stdout(&mut self) -> Option<impl AsyncRead + Unpin + use<>> {
191 if !self.has_stdout {
192 return None;
193 }
194 self.stdout_reader.take()
195 }
196
197 pub fn stderr(&mut self) -> Option<impl AsyncRead + Unpin + use<>> {
211 if !self.has_stderr {
212 return None;
213 }
214 self.stderr_reader.take()
215 }
216
217 #[inline]
219 pub fn abort(&self) {
220 self.task.abort();
221 }
222
223 pub async fn join(self) -> Result<(), Error> {
225 self.task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
226 }
227
228 pub fn take_status(&mut self) -> Option<impl Future<Output = Option<Status>> + use<>> {
232 self.status_rx.take().map(|recv| recv.map(|res| res.ok()))
233 }
234
235 pub fn terminal_size(&mut self) -> Option<TerminalSizeSender> {
252 self.terminal_resize_tx.take()
253 }
254}
255
// Channel identifiers. The first byte of every binary frame names the channel
// the rest of the frame belongs to.

/// Channel byte prefixed to stdin data sent to the server.
const STDIN_CHANNEL: u8 = 0;
/// Channel byte of frames carrying the process's stdout.
const STDOUT_CHANNEL: u8 = 1;
/// Channel byte of frames carrying the process's stderr.
const STDERR_CHANNEL: u8 = 2;
/// Channel byte of the frame carrying the JSON-encoded `Status` object.
const STATUS_CHANNEL: u8 = 3;
/// Channel byte prefixed to JSON-encoded `TerminalSize` updates sent to the server.
const RESIZE_CHANNEL: u8 = 4;
/// Marker byte sent in front of a channel id to tell the server that channel is
/// finished; only used when the connection supports stream close.
const CLOSE_CHANNEL: u8 = 255;
266
267async fn start_message_loop(
268 connection: Connection,
269 stdin: impl AsyncRead + Unpin,
270 mut stdout: Option<impl AsyncWrite + Unpin>,
271 mut stderr: Option<impl AsyncWrite + Unpin>,
272 status_tx: StatusSender,
273 mut terminal_size_rx: Option<TerminalSizeReceiver>,
274) -> Result<(), Error> {
275 let supports_stream_close = connection.supports_stream_close();
276 let stream = connection.into_stream();
277 let mut stdin_stream = tokio_util::io::ReaderStream::new(stdin);
278 let (mut server_send, raw_server_recv) = stream.split();
279 let mut server_recv = raw_server_recv.filter_map(filter_message).boxed();
281 let mut have_terminal_size_rx = terminal_size_rx.is_some();
282
283 let mut stdin_is_open = true;
285
286 loop {
287 let terminal_size_next = async {
288 match terminal_size_rx.as_mut() {
289 Some(tmp) => Some(tmp.next().await),
290 None => None,
291 }
292 };
293 select! {
294 server_message = server_recv.next() => {
295 match server_message {
296 Some(Ok(Message::Stdout(bin))) => {
297 if let Some(stdout) = stdout.as_mut() {
298 stdout.write_all(&bin[1..]).await.map_err(Error::WriteStdout)?;
299 }
300 },
301 Some(Ok(Message::Stderr(bin))) => {
302 if let Some(stderr) = stderr.as_mut() {
303 stderr.write_all(&bin[1..]).await.map_err(Error::WriteStderr)?;
304 }
305 },
306 Some(Ok(Message::Status(bin))) => {
307 let status = serde_json::from_slice::<Status>(&bin[1..]).map_err(Error::DeserializeStatus)?;
308 status_tx.send(status).map_err(|_| Error::SendStatus)?;
309 break
310 },
311 Some(Err(err)) => {
312 return Err(Error::ReceiveWebSocketMessage(err));
313 },
314 None => {
315 break
317 },
318 }
319 },
320 stdin_message = stdin_stream.next(), if stdin_is_open => {
321 match stdin_message {
322 Some(Ok(bytes)) => {
323 if !bytes.is_empty() {
324 let mut vec = Vec::with_capacity(bytes.len() + 1);
325 vec.push(STDIN_CHANNEL);
326 vec.extend_from_slice(&bytes[..]);
327 server_send
328 .send(ws::Message::binary(vec))
329 .await
330 .map_err(Error::SendStdin)?;
331 }
332 },
333 Some(Err(err)) => {
334 return Err(Error::ReadStdin(err));
335 }
336 None => {
337 if supports_stream_close {
340 let vec = vec![CLOSE_CHANNEL, STDIN_CHANNEL];
343 server_send
344 .send(ws::Message::binary(vec))
345 .await
346 .map_err(Error::SendStdin)?;
347 } else {
348 server_send.close().await.map_err(Error::SendClose)?;
352 }
353
354 stdin_is_open = false;
356 }
357 }
358 },
359 Some(terminal_size_message) = terminal_size_next, if have_terminal_size_rx => {
360 match terminal_size_message {
361 Some(new_size) => {
362 let new_size = serde_json::to_vec(&new_size).map_err(Error::SerializeTerminalSize)?;
363 let mut vec = Vec::with_capacity(new_size.len() + 1);
364 vec.push(RESIZE_CHANNEL);
365 vec.extend_from_slice(&new_size[..]);
366 server_send.send(ws::Message::Binary(vec.into())).await.map_err(Error::SendTerminalSize)?;
367 },
368 None => {
369 have_terminal_size_rx = false;
370 }
371 }
372 },
373 }
374 }
375
376 Ok(())
377}
378
/// A server frame that passed filtering, tagged with the channel it arrived on.
///
/// Every variant holds the complete frame, channel byte included, so consumers
/// skip the first byte to get at the payload.
enum Message {
    /// Frame received on the stdout channel.
    Stdout(Vec<u8>),
    /// Frame received on the stderr channel.
    Stderr(Vec<u8>),
    /// Frame received on the status channel.
    Status(Vec<u8>),
}
388
389async fn filter_message(wsm: Result<ws::Message, ws::Error>) -> Option<Result<Message, ws::Error>> {
391 match wsm {
392 Ok(ws::Message::Binary(bin)) if bin.len() > 1 => match bin[0] {
395 STDOUT_CHANNEL => Some(Ok(Message::Stdout(bin.into()))),
396 STDERR_CHANNEL => Some(Ok(Message::Stderr(bin.into()))),
397 STATUS_CHANNEL => Some(Ok(Message::Status(bin.into()))),
398 _ => None,
400 },
401 Ok(_) => None,
405 Err(err) => Some(Err(err)),
408 }
409}