use std::time::Duration;
use bytes::Bytes;
use k8s_openapi::apimachinery::pkg::apis::meta::v1::Status;
use futures::{
FutureExt, SinkExt, StreamExt,
channel::{mpsc, oneshot},
};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream},
select, time,
};
use tokio_tungstenite::tungstenite as ws;
use crate::client::Connection;
use super::AttachParams;
/// Sending half of the one-shot channel that carries the [`Status`] reported by the server.
type StatusSender = oneshot::Sender<Status>;
/// Receiving half of the one-shot channel that carries the [`Status`] reported by the server.
type StatusReceiver = oneshot::Receiver<Status>;
/// Sending half of the bounded channel that carries terminal resize requests.
type TerminalSizeSender = mpsc::Sender<TerminalSize>;
/// Receiving half of the bounded channel that carries terminal resize requests.
type TerminalSizeReceiver = mpsc::Receiver<TerminalSize>;
/// Size of the remote terminal, sent to the server on the resize channel.
///
/// Serialized as JSON with PascalCase keys, i.e. `{"Width":…,"Height":…}`.
/// It is a small plain value, so it is `Copy` and comparable; callers can
/// keep and resend the last size they applied.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
#[serde(rename_all = "PascalCase")]
pub struct TerminalSize {
    /// Terminal width.
    pub width: u16,
    /// Terminal height.
    pub height: u16,
}
#[derive(Debug, Error)]
pub enum Error {
#[error("failed to read from stdin: {0}")]
ReadStdin(#[source] std::io::Error),
#[error("failed to send a stdin data: {0}")]
SendStdin(#[source] ws::Error),
#[error("failed to write to stdout: {0}")]
WriteStdout(#[source] std::io::Error),
#[error("failed to write to stderr: {0}")]
WriteStderr(#[source] std::io::Error),
#[error("failed to receive a WebSocket message: {0}")]
ReceiveWebSocketMessage(#[source] ws::Error),
#[error("failed to complete the background task: {0}")]
Spawn(#[source] tokio::task::JoinError),
#[error("failed to send a WebSocket close message: {0}")]
SendClose(#[source] ws::Error),
#[error("failed to send a WebSocket ping message: {0}")]
SendPing(#[source] ws::Error),
#[error("failed to deserialize status object: {0}")]
DeserializeStatus(#[source] serde_json::Error),
#[error("failed to send status object")]
SendStatus,
#[error("failed to serialize TerminalSize object: {0}")]
SerializeTerminalSize(#[source] serde_json::Error),
#[error("failed to send terminal size message")]
SendTerminalSize(#[source] ws::Error),
#[error("failed to set terminal size, tty need to be true to resize the terminal")]
TtyNeedToBeTrue,
}
/// Default capacity of each in-memory pipe, in bytes (1 KiB). It is used when
/// the corresponding `max_*_buf_size` in [`AttachParams`] is not set.
const MAX_BUF_SIZE: usize = 1 << 10;
/// Handle to a remote process whose standard streams are relayed over a
/// WebSocket connection by a background task.
///
/// The readers and writers are in-memory pipes connected to that task. Each
/// can be taken exactly once through its accessor method.
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct AttachedProcess {
    // Whether each stream was requested in the `AttachParams` passed to `new`.
    has_stdin: bool,
    has_stdout: bool,
    has_stderr: bool,
    // Writing end of the stdin pipe. It is always created, but only handed out
    // when `has_stdin` is set. Otherwise it stays here until `self` is dropped.
    stdin_writer: Option<DuplexStream>,
    // Reading ends of the output pipes; `None` when not requested or already taken.
    stdout_reader: Option<DuplexStream>,
    stderr_reader: Option<DuplexStream>,
    // Receives the `Status` the message loop gets on the status channel.
    status_rx: Option<StatusReceiver>,
    // Resize sender; `Some` only when a TTY was requested, until taken.
    terminal_resize_tx: Option<TerminalSizeSender>,
    // The spawned message loop; its result is surfaced by `join`.
    task: tokio::task::JoinHandle<Result<(), Error>>,
}
impl AttachedProcess {
pub(crate) fn new(connection: Connection, ap: &AttachParams) -> Self {
let (stdin_writer, stdin_reader) = tokio::io::duplex(ap.max_stdin_buf_size.unwrap_or(MAX_BUF_SIZE));
let (stdout_writer, stdout_reader) = if ap.stdout {
let (w, r) = tokio::io::duplex(ap.max_stdout_buf_size.unwrap_or(MAX_BUF_SIZE));
(Some(w), Some(r))
} else {
(None, None)
};
let (stderr_writer, stderr_reader) = if ap.stderr {
let (w, r) = tokio::io::duplex(ap.max_stderr_buf_size.unwrap_or(MAX_BUF_SIZE));
(Some(w), Some(r))
} else {
(None, None)
};
let (status_tx, status_rx) = oneshot::channel();
let (terminal_resize_tx, terminal_resize_rx) = if ap.tty {
let (w, r) = mpsc::channel(10);
(Some(w), Some(r))
} else {
(None, None)
};
let task = tokio::spawn(start_message_loop(
connection,
stdin_reader,
stdout_writer,
stderr_writer,
status_tx,
terminal_resize_rx,
));
AttachedProcess {
has_stdin: ap.stdin,
has_stdout: ap.stdout,
has_stderr: ap.stderr,
task,
stdin_writer: Some(stdin_writer),
stdout_reader,
stderr_reader,
terminal_resize_tx,
status_rx: Some(status_rx),
}
}
pub fn stdin(&mut self) -> Option<impl AsyncWrite + Unpin + use<>> {
if !self.has_stdin {
return None;
}
self.stdin_writer.take()
}
pub fn stdout(&mut self) -> Option<impl AsyncRead + Unpin + use<>> {
if !self.has_stdout {
return None;
}
self.stdout_reader.take()
}
pub fn stderr(&mut self) -> Option<impl AsyncRead + Unpin + use<>> {
if !self.has_stderr {
return None;
}
self.stderr_reader.take()
}
#[inline]
pub fn abort(&self) {
self.task.abort();
}
pub async fn join(self) -> Result<(), Error> {
self.task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
}
pub fn take_status(&mut self) -> Option<impl Future<Output = Option<Status>> + use<>> {
self.status_rx.take().map(|recv| recv.map(|res| res.ok()))
}
pub fn terminal_size(&mut self) -> Option<TerminalSizeSender> {
self.terminal_resize_tx.take()
}
}
// Channel identifiers. Every binary frame exchanged with the server starts with
// one of these bytes, which selects the logical stream the rest of the frame
// belongs to.

/// Client to server: data written to the process's stdin.
const STDIN_CHANNEL: u8 = 0;
/// Server to client: output the process wrote to stdout.
const STDOUT_CHANNEL: u8 = 1;
/// Server to client: output the process wrote to stderr.
const STDERR_CHANNEL: u8 = 2;
/// Server to client: a JSON-encoded `Status` object.
const STATUS_CHANNEL: u8 = 3;
/// Client to server: a JSON-encoded `TerminalSize`.
const RESIZE_CHANNEL: u8 = 4;
/// Client to server: the following byte names a channel that is being closed.
const CLOSE_CHANNEL: u8 = u8::MAX;
async fn start_message_loop(
connection: Connection,
stdin: impl AsyncRead + Unpin,
mut stdout: Option<impl AsyncWrite + Unpin>,
mut stderr: Option<impl AsyncWrite + Unpin>,
status_tx: StatusSender,
mut terminal_size_rx: Option<TerminalSizeReceiver>,
) -> Result<(), Error> {
let supports_stream_close = connection.supports_stream_close();
let stream = connection.into_stream();
let mut stdin_stream = tokio_util::io::ReaderStream::new(stdin);
let (mut server_send, raw_server_recv) = stream.split();
let mut server_recv = raw_server_recv.filter_map(filter_message).boxed();
let mut have_terminal_size_rx = terminal_size_rx.is_some();
let mut stdin_is_open = true;
let mut ping_interval = time::interval(Duration::from_secs(60));
ping_interval.set_missed_tick_behavior(time::MissedTickBehavior::Delay);
ping_interval.reset();
loop {
let terminal_size_next = async {
match terminal_size_rx.as_mut() {
Some(tmp) => Some(tmp.next().await),
None => None,
}
};
select! {
_ = ping_interval.tick() => {
server_send
.send(ws::Message::Ping(Bytes::new()))
.await
.map_err(Error::SendPing)?;
},
server_message = server_recv.next() => {
match server_message {
Some(Ok(Message::Stdout(bin))) => {
if let Some(stdout) = stdout.as_mut() {
stdout.write_all(&bin[1..]).await.map_err(Error::WriteStdout)?;
}
},
Some(Ok(Message::Stderr(bin))) => {
if let Some(stderr) = stderr.as_mut() {
stderr.write_all(&bin[1..]).await.map_err(Error::WriteStderr)?;
}
},
Some(Ok(Message::Status(bin))) => {
let status = serde_json::from_slice::<Status>(&bin[1..]).map_err(Error::DeserializeStatus)?;
status_tx.send(status).map_err(|_| Error::SendStatus)?;
break
},
Some(Err(err)) => {
return Err(Error::ReceiveWebSocketMessage(err));
},
None => {
break
},
}
},
stdin_message = stdin_stream.next(), if stdin_is_open => {
match stdin_message {
Some(Ok(bytes)) => {
if !bytes.is_empty() {
let mut vec = Vec::with_capacity(bytes.len() + 1);
vec.push(STDIN_CHANNEL);
vec.extend_from_slice(&bytes[..]);
server_send
.send(ws::Message::binary(vec))
.await
.map_err(Error::SendStdin)?;
}
},
Some(Err(err)) => {
return Err(Error::ReadStdin(err));
}
None => {
if supports_stream_close {
let vec = vec![CLOSE_CHANNEL, STDIN_CHANNEL];
server_send
.send(ws::Message::binary(vec))
.await
.map_err(Error::SendStdin)?;
} else {
server_send.close().await.map_err(Error::SendClose)?;
}
stdin_is_open = false;
}
}
},
Some(terminal_size_message) = terminal_size_next, if have_terminal_size_rx => {
match terminal_size_message {
Some(new_size) => {
let new_size = serde_json::to_vec(&new_size).map_err(Error::SerializeTerminalSize)?;
let mut vec = Vec::with_capacity(new_size.len() + 1);
vec.push(RESIZE_CHANNEL);
vec.extend_from_slice(&new_size[..]);
server_send.send(ws::Message::Binary(vec.into())).await.map_err(Error::SendTerminalSize)?;
},
None => {
have_terminal_size_rx = false;
}
}
},
}
}
Ok(())
}
/// A server frame that the message loop acts on, as produced by `filter_message`.
///
/// Every variant holds the complete frame, including its leading channel
/// byte, which consumers skip with `&bin[1..]`.
enum Message {
    /// Output the process wrote to stdout.
    Stdout(Vec<u8>),
    /// Output the process wrote to stderr.
    Stderr(Vec<u8>),
    /// A JSON-encoded `Status` object; receiving it ends the message loop.
    Status(Vec<u8>),
}
/// Maps one raw WebSocket read to a [`Message`] the loop cares about. Returns
/// `None` when the message should be skipped.
///
/// Only binary frames longer than one byte are considered, and their first
/// byte selects the channel. Frames on the stdout, stderr, and status channels
/// are kept whole, channel byte included. Frames on any other channel and all
/// non-binary messages are dropped. Read errors are passed through unchanged.
async fn filter_message(wsm: Result<ws::Message, ws::Error>) -> Option<Result<Message, ws::Error>> {
    let frame = match wsm {
        Err(err) => return Some(Err(err)),
        Ok(ws::Message::Binary(frame)) if frame.len() > 1 => frame,
        Ok(_) => return None,
    };

    let channel = frame[0];
    let message = match channel {
        STDOUT_CHANNEL => Message::Stdout(frame.into()),
        STDERR_CHANNEL => Message::Stderr(frame.into()),
        STATUS_CHANNEL => Message::Status(frame.into()),
        _ => return None,
    };
    Some(Ok(message))
}