use axum::{
Json,
extract::{
Path, Query,
ws::{Message, WebSocket, WebSocketUpgrade},
},
http::StatusCode,
response::IntoResponse,
};
use futures::StreamExt;
use manta_backend_dispatcher::{
interfaces::console::{ConsoleAttachment, ConsoleTrait, TermSize},
types::{K8sAuth, K8sDetails},
};
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::Sender;
use super::{
ErrorResponse, RequestCtx, SiteHeader, require_k8s_url, require_vault,
to_handler_error,
};
use crate::service;
pub use manta_shared::types::api::queries::ConsoleQuery;
#[utoipa::path(get, path = "/nodes/{xname}/console", tag = "console",
params(("xname" = String, Path, description = "Node xname"), ConsoleQuery, SiteHeader),
security(("bearerAuth" = [])),
responses(
(status = 101, description = "WebSocket upgrade"),
(status = 401, description = "Unauthorized", body = ErrorResponse),
(status = 500, description = "Internal error", body = ErrorResponse),
(status = 501, description = "Vault or k8s not configured", body = ErrorResponse),
)
)]
#[tracing::instrument(skip_all, fields(xname = %xname))]
pub async fn console_node_ws(
ctx: RequestCtx,
Path(xname): Path<String>,
Query(q): Query<ConsoleQuery>,
ws: WebSocketUpgrade,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
let (k8s_api_url, vault_base_url, timeout) = {
let infra = ctx.infra();
let k = require_k8s_url(infra.k8s_api_url)?.to_string();
let v = require_vault(infra.vault_base_url)?.to_string();
service::authorization::validate_user_group_members_access(
&infra,
&ctx.token,
std::slice::from_ref(&xname),
)
.await
.map_err(to_handler_error)?;
(k, v, ctx.state.console_inactivity_timeout)
};
let k8s = K8sDetails {
api_url: k8s_api_url,
authentication: K8sAuth::Vault {
base_url: vault_base_url,
},
};
let RequestCtx {
state,
token,
site_name,
} = ctx;
Ok(ws.on_upgrade(move |socket| async move {
tracing::info!("WebSocket console opened for node {xname}");
if let Some(site) = state.sites.get(&site_name) {
match site
.backend
.attach_to_node_console(
&token,
&site_name,
&xname,
TermSize {
width: q.cols,
height: q.rows,
},
&k8s,
)
.await
{
Ok(ConsoleAttachment {
stdin,
stdout,
resize,
}) => {
run_console_bridge(socket, stdin, stdout, resize, timeout).await;
tracing::info!("WebSocket console closed for node {xname}");
}
Err(e) => {
tracing::error!("Failed to attach to node console {xname}: {e:#}");
}
}
}
}))
}
#[utoipa::path(get, path = "/sessions/{name}/console", tag = "console",
params(("name" = String, Path, description = "Session name"), ConsoleQuery, SiteHeader),
security(("bearerAuth" = [])),
responses(
(status = 101, description = "WebSocket upgrade"),
(status = 401, description = "Unauthorized", body = ErrorResponse),
(status = 500, description = "Internal error", body = ErrorResponse),
(status = 501, description = "Vault or k8s not configured", body = ErrorResponse),
)
)]
#[tracing::instrument(skip_all, fields(session = %name))]
pub async fn console_session_ws(
ctx: RequestCtx,
Path(name): Path<String>,
Query(q): Query<ConsoleQuery>,
ws: WebSocketUpgrade,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
let (k8s_api_url, vault_base_url, timeout) = {
let infra = ctx.infra();
let k = require_k8s_url(infra.k8s_api_url)?.to_string();
let v = require_vault(infra.vault_base_url)?.to_string();
service::session::validate_session_access(&infra, &ctx.token, &name)
.await
.map_err(to_handler_error)?;
service::session::validate_console_session(&infra, &ctx.token, &name)
.await
.map_err(to_handler_error)?;
(k, v, ctx.state.console_inactivity_timeout)
};
let k8s = K8sDetails {
api_url: k8s_api_url,
authentication: K8sAuth::Vault {
base_url: vault_base_url,
},
};
let RequestCtx {
state,
token,
site_name,
} = ctx;
Ok(ws.on_upgrade(move |socket| async move {
tracing::info!("WebSocket console opened for session {name}");
if let Some(site) = state.sites.get(&site_name) {
match site
.backend
.attach_to_session_console(
&token,
&site_name,
&name,
TermSize {
width: q.cols,
height: q.rows,
},
&k8s,
)
.await
{
Ok(ConsoleAttachment {
stdin,
stdout,
resize,
}) => {
run_console_bridge(socket, stdin, stdout, resize, timeout).await;
tracing::info!("WebSocket console closed for session {name}");
}
Err(e) => {
tracing::error!("Failed to attach to session console {name}: {e:#}");
}
}
}
}))
}
/// The two WebSocket operations `run_console_bridge` relies on, abstracted so
/// the bridge can be exercised against an in-memory socket in unit tests.
// `async fn` in a trait makes no `Send` promise about the returned futures.
// That is acceptable here because the trait is private and only implemented
// for concrete types in this module, so the lint is silenced on purpose.
#[allow(async_fn_in_trait)]
trait ConsoleSocket: Send + Unpin {
    /// Sends a single message to the peer.
    async fn send(&mut self, msg: Message) -> Result<(), axum::Error>;

    /// Waits for the next message from the peer. The bridge treats `None` as
    /// the connection having ended.
    async fn recv(&mut self) -> Option<Result<Message, axum::Error>>;
}
/// Production [`ConsoleSocket`] backed by axum's `WebSocket`.
///
/// Both methods delegate to the inherent `WebSocket` methods via
/// `WebSocket::method(self, ..)` paths. Inherent methods take precedence over
/// trait methods in such paths, so these calls do not recurse into this impl.
impl ConsoleSocket for WebSocket {
    async fn send(&mut self, msg: Message) -> Result<(), axum::Error> {
        WebSocket::send(self, msg).await
    }

    async fn recv(&mut self) -> Option<Result<Message, axum::Error>> {
        WebSocket::recv(self).await
    }
}
async fn run_console_bridge<S: ConsoleSocket>(
mut socket: S,
mut console_in: Box<dyn tokio::io::AsyncWrite + Unpin + Send>,
console_out: Box<dyn tokio::io::AsyncRead + Unpin + Send>,
resize: Sender<TermSize>,
inactivity_timeout: std::time::Duration,
) {
let mut out_stream = tokio_util::io::ReaderStream::new(console_out);
let mut deadline = tokio::time::Instant::now() + inactivity_timeout;
loop {
tokio::select! {
msg = socket.recv() => {
match msg {
Some(Ok(Message::Binary(data))) => {
deadline = tokio::time::Instant::now() + inactivity_timeout;
if console_in.write_all(&data).await.is_err() { break; }
}
Some(Ok(Message::Text(text))) => {
deadline = tokio::time::Instant::now() + inactivity_timeout;
if let Ok(v) = serde_json::from_str::<serde_json::Value>(&text)
&& v.get("type").and_then(|t| t.as_str()) == Some("resize")
{
let cols = v.get("cols").and_then(|c| c.as_u64()).unwrap_or(0);
let rows = v.get("rows").and_then(|r| r.as_u64()).unwrap_or(0);
if cols > 0 && rows > 0 && cols <= u64::from(u16::MAX) && rows <= u64::from(u16::MAX) {
#[allow(clippy::cast_possible_truncation)]
let size = TermSize {
width: cols as u16,
height: rows as u16,
};
let _ = resize.try_send(size);
}
continue;
}
if console_in.write_all(text.as_bytes()).await.is_err() { break; }
}
Some(Ok(Message::Close(_))) | None => break,
Some(Ok(_)) => {} Some(Err(_)) => break,
}
}
chunk = out_stream.next() => {
match chunk {
Some(Ok(data)) => {
if socket.send(Message::Binary(data)).await.is_err() { break; }
}
Some(Err(_)) | None => break,
}
}
() = tokio::time::sleep_until(deadline) => {
tracing::warn!("Console session idle for {:?}, closing", inactivity_timeout);
break;
}
}
}
}
// Unit tests for `run_console_bridge`. The bridge is driven through an
// in-memory `MockSocket`, and every test runs on tokio's paused clock
// (`start_paused = true`), so the timeouts below are measured in virtual time.
#[cfg(test)]
mod tests {
use super::*;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::sync::mpsc;
// `AsyncWrite` stand-in for console stdin. It appends every write to a shared
// buffer so a test can inspect what the bridge forwarded.
struct CaptureWriter(Arc<Mutex<Vec<u8>>>);
impl tokio::io::AsyncWrite for CaptureWriter {
fn poll_write(
self: Pin<&mut Self>,
_: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.0.lock().unwrap().extend_from_slice(buf);
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
}
// `AsyncRead` stand-in for console stdout that never yields data or EOF, i.e.
// a console that stays silent forever.
struct PendingReader;
impl tokio::io::AsyncRead for PendingReader {
fn poll_read(
self: Pin<&mut Self>,
_: &mut Context<'_>,
_: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Pending
}
}
// In-memory `ConsoleSocket`. `rx` delivers client->server frames injected by
// the test; `tx` records server->client frames sent by the bridge.
struct MockSocket {
rx: mpsc::UnboundedReceiver<Result<Message, axum::Error>>,
tx: mpsc::UnboundedSender<Message>,
}
impl ConsoleSocket for MockSocket {
async fn recv(&mut self) -> Option<Result<Message, axum::Error>> {
self.rx.recv().await
}
async fn send(&mut self, msg: Message) -> Result<(), axum::Error> {
self.tx.send(msg).map_err(axum::Error::new)
}
}
// Builds a `MockSocket` and returns it together with the test-side channel
// ends: a sender for injecting client frames and a receiver for the frames
// the bridge sends to the client.
fn new_mock_socket() -> (
MockSocket,
mpsc::UnboundedSender<Result<Message, axum::Error>>,
mpsc::UnboundedReceiver<Message>,
) {
let (in_tx, in_rx) = mpsc::unbounded_channel();
let (out_tx, out_rx) = mpsc::unbounded_channel();
(
MockSocket {
rx: in_rx,
tx: out_tx,
},
in_tx,
out_rx,
)
}
// Waits up to `cap` (virtual time) for the bridge task. Returns `true` if it
// finished within that window and `false` otherwise. The handle is borrowed,
// so it can be polled again by a later call.
async fn bridge_exited_within(
handle: &mut tokio::task::JoinHandle<()>,
cap: Duration,
) -> bool {
tokio::select! {
_ = handle => true,
() = tokio::time::sleep(cap) => false,
}
}
// With no traffic in either direction, the bridge must stay up until the 60s
// deadline and then exit shortly after.
#[tokio::test(start_paused = true)]
async fn inactivity_timeout_fires_when_no_traffic() {
let (socket, _in_tx, _out_rx) = new_mock_socket();
let console_in = Box::new(tokio::io::sink());
let console_out = Box::new(PendingReader);
let (resize_tx, _resize_rx) = mpsc::channel(8);
let mut handle = tokio::spawn(async move {
run_console_bridge(
socket,
console_in,
console_out,
resize_tx,
Duration::from_secs(60),
)
.await;
});
assert!(
!bridge_exited_within(&mut handle, Duration::from_secs(59)).await,
"bridge exited before the 60s inactivity timeout"
);
assert!(
bridge_exited_within(&mut handle, Duration::from_secs(5)).await,
"bridge did not exit after the inactivity timeout"
);
}
// A binary frame at t=59s pushes the deadline to about t=119s. The bridge is
// still alive at t=90s and gone by t=125s.
#[tokio::test(start_paused = true)]
async fn client_binary_message_resets_deadline() {
let (socket, in_tx, _out_rx) = new_mock_socket();
let console_in = Box::new(tokio::io::sink());
let console_out = Box::new(PendingReader);
let (resize_tx, _resize_rx) = mpsc::channel(8);
let mut handle = tokio::spawn(async move {
run_console_bridge(
socket,
console_in,
console_out,
resize_tx,
Duration::from_secs(60),
)
.await;
});
tokio::time::sleep(Duration::from_secs(59)).await;
in_tx
.send(Ok(Message::Binary(b"hi".to_vec().into())))
.unwrap();
// Let the bridge task observe the frame before measuring again.
tokio::task::yield_now().await;
assert!(
!bridge_exited_within(&mut handle, Duration::from_secs(31)).await,
"deadline was not reset by client binary message"
);
assert!(
bridge_exited_within(&mut handle, Duration::from_secs(35)).await,
"bridge did not exit after the reset deadline"
);
}
// A JSON resize text frame must (1) reset the deadline, (2) not be written to
// console stdin, and (3) be forwarded as a `TermSize` on the resize channel.
#[tokio::test(start_paused = true)]
async fn resize_text_forwards_to_resize_channel_and_resets_deadline() {
let (socket, in_tx, _out_rx) = new_mock_socket();
let written: Arc<Mutex<Vec<u8>>> = Default::default();
let console_in = Box::new(CaptureWriter(written.clone()));
let console_out = Box::new(PendingReader);
let (resize_tx, mut resize_rx) = mpsc::channel(8);
let mut handle = tokio::spawn(async move {
run_console_bridge(
socket,
console_in,
console_out,
resize_tx,
Duration::from_secs(60),
)
.await;
});
tokio::time::sleep(Duration::from_secs(59)).await;
in_tx
.send(Ok(Message::Text(
r#"{"type":"resize","cols":120,"rows":40}"#.into(),
)))
.unwrap();
tokio::task::yield_now().await;
assert!(
!bridge_exited_within(&mut handle, Duration::from_secs(30)).await,
"deadline was not reset by resize message"
);
assert!(
written.lock().unwrap().is_empty(),
"resize text frame was forwarded to console stdin (should be parsed)"
);
let size = resize_rx.try_recv().expect(
"resize message should have been forwarded to the resize channel",
);
assert_eq!(size.width, 120);
assert_eq!(size.height, 40);
// The bridge is still running (its deadline is in the future); stop it.
handle.abort();
}
// A client Close frame ends the bridge immediately, long before the 1h timeout.
#[tokio::test(start_paused = true)]
async fn client_close_exits_loop_immediately() {
let (socket, in_tx, _out_rx) = new_mock_socket();
let console_in = Box::new(tokio::io::sink());
let console_out = Box::new(PendingReader);
let (resize_tx, _resize_rx) = mpsc::channel(8);
let mut handle = tokio::spawn(async move {
run_console_bridge(
socket,
console_in,
console_out,
resize_tx,
Duration::from_secs(3600),
)
.await;
});
in_tx.send(Ok(Message::Close(None))).unwrap();
assert!(
bridge_exited_within(&mut handle, Duration::from_secs(1)).await,
"bridge did not exit on Close frame"
);
}
// The console emits one chunk ("chunk") and then goes silent. Sending that
// output to the client must not count as activity, so the 60s timeout still
// fires.
#[tokio::test(start_paused = true)]
async fn server_to_client_data_does_not_reset_deadline() {
use tokio::io::AsyncReadExt;
let (socket, _in_tx, mut out_rx) = new_mock_socket();
let console_in = Box::new(tokio::io::sink());
let console_out =
Box::new(std::io::Cursor::new(b"chunk".to_vec()).chain(PendingReader));
let (resize_tx, _resize_rx) = mpsc::channel(8);
let mut handle = tokio::spawn(async move {
run_console_bridge(
socket,
console_in,
console_out,
resize_tx,
Duration::from_secs(60),
)
.await;
});
// Keep draining server->client frames so the mock's sends never fail.
tokio::spawn(async move { while out_rx.recv().await.is_some() {} });
assert!(
bridge_exited_within(&mut handle, Duration::from_secs(65)).await,
"server-to-client data should NOT keep the deadline alive"
);
}
}