use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use axum::body::{Body, Bytes};
use axum::extract::{ws::WebSocketUpgrade, Path as AxumPath, State};
use axum::http::{header, HeaderMap, HeaderName, HeaderValue, Method, StatusCode, Uri};
use axum::response::{IntoResponse, Response};
use axum::routing::{any, post};
use axum::Router;
use base64::Engine as _;
use clap::ValueEnum;
use futures_util::{SinkExt, StreamExt};
use rand::RngCore;
use serde::Deserialize;
use tokio::sync::RwLock;
use tower_http::services::ServeDir;
use crate::client::CellosClient;
use crate::config::Config;
use crate::exit::{CtlError, CtlResult};
/// Directory name of the built web UI bundle, joined onto the crate's
/// `CARGO_MANIFEST_DIR` when no override directory is supplied.
const BUNDLE_DIR_RELATIVE: &str = "static";
// Selects which listener(s) the web UI server binds.
// (Plain `//` comments are used on purpose: `///` on `ValueEnum` variants
// would be picked up by clap as help text.)
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, ValueEnum)]
pub enum BindMode {
// Loopback TCP, plus a Unix domain socket on Unix platforms. This is the default.
#[default]
Auto,
// Loopback TCP only.
Loopback,
// Unix domain socket only; rejected with a usage error on non-Unix platforms.
Unix,
}
/// State shared by every request handler. All fields are `Arc`s, so cloning
/// the struct only bumps reference counts.
#[derive(Clone)]
struct AppState {
/// Base URL of the upstream server; `run` stores it with trailing `/` removed.
upstream_base: Arc<String>,
/// Optional bearer token attached to requests sent upstream.
upstream_bearer: Arc<Option<String>>,
/// One-time token that `POST /auth/exchange` trades for a session cookie.
session_token: Arc<String>,
/// Cookie value minted by a successful exchange; `None` until then.
session_cookie: Arc<RwLock<Option<String>>>,
/// Set to `true` once the session token has been exchanged.
exchange_consumed: Arc<std::sync::atomic::AtomicBool>,
/// Directory served as static files by the router's fallback service.
bundle_dir: Arc<PathBuf>,
}
/// JSON body accepted by `POST /auth/exchange`.
#[derive(Deserialize)]
struct ExchangeRequest {
/// Session token presented by the caller; compared against `AppState::session_token`.
sess: String,
}
/// Runs the local web UI server until a shutdown signal arrives.
///
/// Binds a loopback TCP listener and/or a Unix domain socket according to
/// `bind`, prints the resulting URL(s), optionally opens the loopback URL in
/// a browser when `open` is set, and serves the router built by
/// `build_router` on every bound listener.
///
/// # Errors
///
/// Returns a usage error when the bundle's `index.html` is missing or a
/// listener cannot be bound, and an API error when a server task ends with
/// an error or fails to join.
pub async fn run(cfg: &Config, open: bool, bind: BindMode) -> CtlResult<()> {
// The client is built and immediately dropped; presumably this validates the
// configuration up front — TODO confirm against `CellosClient::new`.
let _ = CellosClient::new(cfg)?;
let upstream_base = cfg.effective_server().trim_end_matches('/').to_string();
let upstream_bearer = cfg.effective_token();
let bundle_dir = resolve_bundle_dir()?;
if !bundle_dir.join("index.html").exists() {
return Err(CtlError::usage(format!(
"webui bundle not found at {}/index.html — run `npm --prefix web run build` first",
bundle_dir.display()
)));
}
let session_token = mint_session_token();
let upstream_log = upstream_base.clone();
let state = AppState {
upstream_base: Arc::new(upstream_base),
upstream_bearer: Arc::new(upstream_bearer),
session_token: Arc::new(session_token.clone()),
session_cookie: Arc::new(RwLock::new(None)),
exchange_consumed: Arc::new(std::sync::atomic::AtomicBool::new(false)),
bundle_dir: Arc::new(bundle_dir.clone()),
};
let app = build_router(state);
let (want_tcp, want_unix) = resolve_bind_plan(bind)?;
// Port 0 asks the OS for any free port; the real port is read back below.
let tcp_listener = if want_tcp {
let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
Some(
tokio::net::TcpListener::bind(addr)
.await
.map_err(|e| CtlError::usage(format!("bind 127.0.0.1: {e}")))?,
)
} else {
None
};
// Only the socket path is computed here; the socket itself is bound later,
// after the URLs have been printed.
#[cfg(unix)]
let unix_socket_path: Option<PathBuf> = if want_unix {
Some(unix_socket_path_for_pid(std::process::id()))
} else {
None
};
#[cfg(not(unix))]
let unix_socket_path: Option<PathBuf> = None;
// The session token travels in the URL fragment (`#sess=`).
let browser_url = if let Some(l) = tcp_listener.as_ref() {
let local_addr = l
.local_addr()
.map_err(|e| CtlError::usage(format!("local_addr: {e}")))?;
Some(format!("http://{}/#sess={}", local_addr, session_token))
} else {
None
};
if let Some(url) = browser_url.as_ref() {
println!("cellctl webui: {}", url);
}
if let Some(p) = unix_socket_path.as_ref() {
println!("cellctl webui: unix://{}", p.display());
}
// Without a TCP listener there is nothing a browser can reach directly.
if browser_url.is_none() {
eprintln!(
"cellctl webui: --bind unix has no browser-reachable URL; \
use a forwarder (e.g. `socat TCP-LISTEN:0,reuseaddr,fork UNIX-CONNECT:{}`) \
or rerun with `--bind auto` for a loopback URL.",
unix_socket_path
.as_ref()
.map(|p| p.display().to_string())
.unwrap_or_else(|| "<socket>".to_string())
);
}
println!("upstream: {}", upstream_log);
println!("press Ctrl-C to stop");
if open {
if let Some(url) = browser_url.as_ref() {
// Sanity-check the URL before handing it to the OS opener.
if !is_safe_open_url(url) {
eprintln!(
"cellctl webui: refusing to open URL (failed loopback-http sanity check)"
);
} else if let Err(e) = opener::open(url) {
eprintln!("cellctl webui: could not launch browser: {e}");
}
} else {
eprintln!("cellctl webui: --open ignored: no loopback URL bound (use --bind auto)");
}
}
// One broadcast channel fans the shutdown signal out to every listener task.
// `_shutdown_rx` is kept alive for the rest of the function so the channel
// always has at least one receiver.
let (shutdown_tx, _shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
{
let shutdown_tx = shutdown_tx.clone();
tokio::spawn(async move {
wait_for_shutdown_signal().await;
eprintln!("shutting down");
let _ = shutdown_tx.send(());
});
}
let tcp_task = if let Some(listener) = tcp_listener {
let app = app.clone();
let mut rx = shutdown_tx.subscribe();
Some(tokio::spawn(async move {
axum::serve(listener, app)
.with_graceful_shutdown(async move {
let _ = rx.recv().await;
})
.await
}))
} else {
None
};
// NOTE(review): a bind failure here returns early via `?` while the TCP task
// spawned above is still running — confirm the caller exits the process.
#[cfg(unix)]
let unix_task = if let Some(path) = unix_socket_path.clone() {
let app = app.clone();
let mut rx = shutdown_tx.subscribe();
let listener = bind_unix_listener(&path)?;
Some(tokio::spawn(async move {
serve_unix(listener, app, async move {
let _ = rx.recv().await;
})
.await
}))
} else {
None
};
#[cfg(not(unix))]
let unix_task: Option<tokio::task::JoinHandle<std::io::Result<()>>> = None;
// Wait for both servers to finish; only the first failure is reported.
let mut first_err: Option<String> = None;
if let Some(t) = tcp_task {
match t.await {
Ok(Ok(())) => {}
Ok(Err(e)) => {
let _ = first_err.get_or_insert_with(|| format!("tcp: {e}"));
}
Err(e) => {
let _ = first_err.get_or_insert_with(|| format!("tcp join: {e}"));
}
}
}
if let Some(t) = unix_task {
match t.await {
Ok(Ok(())) => {}
Ok(Err(e)) => {
let _ = first_err.get_or_insert_with(|| format!("unix: {e}"));
}
Err(e) => {
let _ = first_err.get_or_insert_with(|| format!("unix join: {e}"));
}
}
}
// Best-effort removal of the socket file; errors are ignored.
if let Some(p) = unix_socket_path.as_ref() {
let _ = std::fs::remove_file(p);
}
if let Some(e) = first_err {
return Err(CtlError::api(format!("webui server: {e}")));
}
Ok(())
}
/// Maps a [`BindMode`] to `(want_tcp, want_unix)` listener flags.
///
/// On Unix, `Auto` enables both listeners. Elsewhere `Auto` behaves like
/// `Loopback`, and `Unix` is rejected with a usage error.
fn resolve_bind_plan(bind: BindMode) -> CtlResult<(bool, bool)> {
    #[cfg(unix)]
    let plan = match bind {
        BindMode::Auto => (true, true),
        BindMode::Loopback => (true, false),
        BindMode::Unix => (false, true),
    };
    #[cfg(not(unix))]
    let plan = match bind {
        BindMode::Auto | BindMode::Loopback => (true, false),
        BindMode::Unix => {
            return Err(CtlError::usage(
                "--bind unix is not supported on Windows; use --bind loopback (the default)",
            ));
        }
    };
    Ok(plan)
}
/// Computes the Unix socket path for the process with the given `pid`.
///
/// The socket lives in `$XDG_RUNTIME_DIR` when that variable holds an
/// absolute path. An unset, empty, or relative value is ignored (the XDG
/// Base Directory spec requires absolute paths), in which case the socket
/// goes under `/tmp/cellctl-<uid>`. Previously an empty or relative value
/// produced a socket path relative to the current working directory.
#[cfg(unix)]
fn unix_socket_path_for_pid(pid: u32) -> PathBuf {
    let runtime_dir = std::env::var_os("XDG_RUNTIME_DIR")
        .map(PathBuf::from)
        .filter(|candidate| candidate.is_absolute());
    let dir = runtime_dir.unwrap_or_else(|| {
        // SAFETY: `getuid` takes no arguments, has no preconditions, and
        // always succeeds.
        let uid = unsafe { libc::getuid() };
        PathBuf::from("/tmp").join(format!("cellctl-{uid}"))
    });
    dir.join(format!("cellctl-webui-{pid}.sock"))
}
/// Binds a Unix domain socket listener at `path` with owner-only permissions.
///
/// Creates the parent directory when it does not exist, removes any file
/// already at `path`, binds the socket under a temporary `0o077` umask, and
/// finally sets the socket file's mode to `0600`.
///
/// # Errors
///
/// Returns a usage error when the parent directory cannot be created, the
/// bind fails, or the final `chmod` fails.
#[cfg(unix)]
fn bind_unix_listener(path: &std::path::Path) -> CtlResult<tokio::net::UnixListener> {
use std::os::unix::fs::PermissionsExt;
if let Some(parent) = path.parent() {
if !parent.exists() {
std::fs::create_dir_all(parent)
.map_err(|e| CtlError::usage(format!("mkdir {}: {e}", parent.display())))?;
}
// Best effort: a failure to tighten the directory mode is ignored.
// NOTE(review): if the directory already exists and is owned by another
// user this silently fails — confirm that is acceptable.
let _ = std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o700));
}
// Clear a stale socket from a previous run; a missing file is not an error.
let _ = std::fs::remove_file(path);
// Narrow the umask around the bind so the socket is created without
// group/other access, then restore the previous mask.
// NOTE(review): the umask is process-wide, so files created by other threads
// during this window are affected too — confirm acceptable.
let prev_umask = unsafe { libc::umask(0o077) };
let bind_result = tokio::net::UnixListener::bind(path);
unsafe {
let _ = libc::umask(prev_umask);
}
let listener =
bind_result.map_err(|e| CtlError::usage(format!("bind {}: {e}", path.display())))?;
// Set the mode explicitly as well, so it does not depend on the umask alone.
let perms = std::fs::Permissions::from_mode(0o600);
std::fs::set_permissions(path, perms)
.map_err(|e| CtlError::usage(format!("chmod 0600 {}: {e}", path.display())))?;
Ok(listener)
}
/// Resolves once the process is asked to stop.
///
/// On Unix this waits for SIGINT or SIGTERM. If the SIGINT handler cannot be
/// installed it falls back to `ctrl_c`; if only the SIGTERM handler fails it
/// waits for SIGINT alone. On other platforms it waits for Ctrl-C.
async fn wait_for_shutdown_signal() {
    #[cfg(unix)]
    {
        use tokio::signal::unix::{signal, SignalKind};
        let Ok(mut interrupt) = signal(SignalKind::interrupt()) else {
            let _ = tokio::signal::ctrl_c().await;
            return;
        };
        let Ok(mut terminate) = signal(SignalKind::terminate()) else {
            let _ = interrupt.recv().await;
            return;
        };
        tokio::select! {
            _ = interrupt.recv() => {},
            _ = terminate.recv() => {},
        }
    }
    #[cfg(not(unix))]
    {
        let _ = tokio::signal::ctrl_c().await;
    }
}
/// Returns `true` when `s` is safe to hand to the OS URL opener.
///
/// The string must contain no control characters, parse as a URL with an
/// `http` or `https` scheme, and name a loopback host.
fn is_safe_open_url(s: &str) -> bool {
    const LOOPBACK_HOSTS: [&str; 4] = ["127.0.0.1", "localhost", "::1", "[::1]"];
    if s.chars().any(char::is_control) {
        return false;
    }
    match url::Url::parse(s) {
        Ok(parsed) if matches!(parsed.scheme(), "http" | "https") => parsed
            .host_str()
            .map_or(false, |host| LOOPBACK_HOSTS.contains(&host)),
        _ => false,
    }
}
/// Serves `app` on a Unix domain socket until `shutdown` resolves.
///
/// Each accepted connection is driven on its own spawned task, with HTTP
/// upgrades enabled so WebSocket requests work. Connections still in flight
/// when `shutdown` fires are not awaited.
///
/// Accept errors are logged and followed by a short pause before retrying.
/// Without the pause a persistent failure (for example, running out of file
/// descriptors) would spin the loop at full CPU.
#[cfg(unix)]
async fn serve_unix(
    listener: tokio::net::UnixListener,
    app: Router,
    shutdown: impl std::future::Future<Output = ()>,
) -> std::io::Result<()> {
    use std::convert::Infallible;
    use tower::Service;
    /// Pause after a failed `accept()` before trying again.
    const ACCEPT_BACKOFF: std::time::Duration = std::time::Duration::from_millis(100);
    tokio::pin!(shutdown);
    loop {
        tokio::select! {
            _ = &mut shutdown => return Ok(()),
            accepted = listener.accept() => {
                let (stream, _peer) = match accepted {
                    Ok(s) => s,
                    Err(e) => {
                        eprintln!("webui: unix accept error: {e}");
                        tokio::time::sleep(ACCEPT_BACKOFF).await;
                        continue;
                    }
                };
                let app = app.clone();
                tokio::spawn(async move {
                    // Adapt the axum router to a hyper service; the router is
                    // cloned per request because `call` needs `&mut self`.
                    let svc = hyper::service::service_fn(move |req: http::Request<hyper::body::Incoming>| {
                        let mut router = app.clone();
                        async move {
                            let resp: Response = match router.call(req.map(Body::new)).await {
                                Ok(r) => r,
                                Err(never) => match never {},
                            };
                            Ok::<_, Infallible>(resp)
                        }
                    });
                    let io = hyper_util::rt::TokioIo::new(stream);
                    // Per-connection errors are ignored.
                    let _ = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
                        .serve_connection_with_upgrades(io, svc)
                        .await;
                });
            }
        }
    }
}
/// Generates a fresh session token: 32 random bytes from the thread-local
/// RNG, encoded as unpadded URL-safe base64.
fn mint_session_token() -> String {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    let mut entropy = [0u8; 32];
    rand::thread_rng().fill_bytes(&mut entropy);
    URL_SAFE_NO_PAD.encode(entropy)
}
/// Locates the directory holding the built web UI bundle.
///
/// `CELLCTL_WEBUI_BUNDLE_DIR` wins when set. It is read with `var_os` so a
/// path that is not valid UTF-8 is honoured instead of being silently
/// ignored. Otherwise the bundle is expected under the crate's manifest
/// directory, a path fixed at compile time.
///
/// Currently always returns `Ok`; the `Result` is kept so callers are
/// unaffected.
fn resolve_bundle_dir() -> CtlResult<PathBuf> {
    let dir = match std::env::var_os("CELLCTL_WEBUI_BUNDLE_DIR") {
        Some(overridden) => PathBuf::from(overridden),
        None => PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(BUNDLE_DIR_RELATIVE),
    };
    Ok(dir)
}
/// Assembles the application router.
///
/// Routes: `POST /auth/exchange`, the `/v1/*` proxy, and the `/ws/events`
/// WebSocket relay. Everything else falls through to static files from the
/// bundle directory. The method-filter middleware is layered over all of them.
fn build_router(state: AppState) -> Router {
    let static_files =
        ServeDir::new(state.bundle_dir.as_path()).append_index_html_on_directories(true);
    let routes = Router::new()
        .route("/auth/exchange", post(auth_exchange))
        .route("/v1/*rest", any(proxy_v1))
        .route("/ws/events", any(ws_events))
        .fallback_service(static_files);
    routes
        .layer(axum::middleware::from_fn(reject_non_get))
        .with_state(state)
}
/// Middleware that lets through only `GET` requests and `POST /auth/exchange`.
///
/// Any other method/path combination is answered with `405` before it
/// reaches a handler.
async fn reject_non_get(req: axum::http::Request<Body>, next: axum::middleware::Next) -> Response {
    let permitted = matches!(
        (req.method(), req.uri().path()),
        (&Method::GET, _) | (&Method::POST, "/auth/exchange")
    );
    if !permitted {
        return method_not_allowed();
    }
    next.run(req).await
}
/// Builds the `405 Method Not Allowed` response, advertising `Allow: GET`.
fn method_not_allowed() -> Response {
    (
        StatusCode::METHOD_NOT_ALLOWED,
        [(header::ALLOW, HeaderValue::from_static("GET"))],
        "method not allowed\n",
    )
        .into_response()
}
/// `Max-Age` attribute, in seconds, of the session cookie (86 400 s = 24 h).
const COOKIE_MAX_AGE_SECS: u64 = 86_400;
async fn auth_exchange(State(state): State<AppState>, headers: HeaderMap, body: Bytes) -> Response {
use std::sync::atomic::Ordering;
use subtle::ConstantTimeEq;
let ct_ok = headers
.get(header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(|s| {
let trimmed = s.split(';').next().unwrap_or("").trim();
trimmed.eq_ignore_ascii_case("application/json")
})
.unwrap_or(false);
if !ct_ok {
let mut resp = (
StatusCode::UNSUPPORTED_MEDIA_TYPE,
"Content-Type must be application/json\n",
)
.into_response();
resp.headers_mut().insert(
header::CONTENT_TYPE,
HeaderValue::from_static("text/plain; charset=utf-8"),
);
return resp;
}
let parsed: ExchangeRequest = match serde_json::from_slice(&body) {
Ok(p) => p,
Err(_) => return (StatusCode::BAD_REQUEST, "invalid json body\n").into_response(),
};
if state.exchange_consumed.load(Ordering::SeqCst) {
return (
StatusCode::UNAUTHORIZED,
"session token already exchanged\n",
)
.into_response();
}
let provided = parsed.sess.as_bytes();
let expected = state.session_token.as_bytes();
if provided.len() != expected.len() || provided.ct_eq(expected).unwrap_u8() == 0 {
return (StatusCode::UNAUTHORIZED, "bad session token\n").into_response();
}
if state
.exchange_consumed
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_err()
{
return (
StatusCode::UNAUTHORIZED,
"session token already exchanged\n",
)
.into_response();
}
let mut buf = [0u8; 32];
rand::thread_rng().fill_bytes(&mut buf);
let cookie_value = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(buf);
{
let mut slot = state.session_cookie.write().await;
*slot = Some(cookie_value.clone());
}
let cookie_header = format!(
"cellctl_session={}; HttpOnly; SameSite=Strict; Path=/; Max-Age={}",
cookie_value, COOKIE_MAX_AGE_SECS
);
let mut resp = (StatusCode::OK, "ok\n").into_response();
resp.headers_mut().insert(
header::SET_COOKIE,
HeaderValue::from_str(&cookie_header).unwrap(),
);
resp
}
async fn require_session_cookie(state: &AppState, headers: &HeaderMap) -> bool {
let expected = match state.session_cookie.read().await.clone() {
Some(v) => v,
None => return false,
};
let Some(cookie_hdr) = headers.get(header::COOKIE) else {
return false;
};
let Ok(cookie_str) = cookie_hdr.to_str() else {
return false;
};
for entry in cookie_str.split(';') {
let entry = entry.trim();
if let Some(v) = entry.strip_prefix("cellctl_session=") {
return v == expected;
}
}
false
}
/// `GET /v1/*`: forwards the request to the upstream server and relays the
/// status, `Content-Type`, and body of its response.
///
/// Requires a valid session cookie (`401` otherwise). Upstream failures are
/// reported as `502`.
///
/// The upstream URL is built from the request's raw, still percent-encoded
/// path. The `rest` capture is percent-decoded by the extractor, so splicing
/// it back in would let an encoded `?`, `#`, or space be reinterpreted
/// upstream. The decoded capture is used only to reject `.`/`..` segments,
/// which URL normalisation would otherwise resolve and so let a request
/// escape the `/v1/` prefix while carrying the upstream bearer token.
///
/// NOTE(review): a new `reqwest::Client` is built per request, so connections
/// are not pooled; sharing one client would require a field on `AppState`.
async fn proxy_v1(
    State(state): State<AppState>,
    AxumPath(rest): AxumPath<String>,
    uri: Uri,
    headers: HeaderMap,
) -> Response {
    if !require_session_cookie(&state, &headers).await {
        return unauthorized();
    }
    // Backslash is treated as a separator too, because URL parsing of http(s)
    // URLs treats `\` like `/`.
    let has_dot_segment = rest
        .split(|c: char| c == '/' || c == '\\')
        .any(|segment| matches!(segment, "." | ".."));
    if has_dot_segment {
        return (StatusCode::BAD_REQUEST, "invalid path\n").into_response();
    }
    let query = uri.query().map(|q| format!("?{q}")).unwrap_or_default();
    // The route pattern guarantees `uri.path()` starts with `/v1/`.
    let upstream_url = format!("{}{}{}", state.upstream_base, uri.path(), query);
    let client = match reqwest::Client::builder().build() {
        Ok(c) => c,
        Err(e) => return upstream_error(format!("client: {e}")),
    };
    let mut req = client.get(&upstream_url);
    if let Some(tok) = state.upstream_bearer.as_ref() {
        req = req.header(reqwest::header::AUTHORIZATION, format!("Bearer {tok}"));
    }
    let upstream_resp = match req.send().await {
        Ok(r) => r,
        Err(e) => return upstream_error(format!("send: {e}")),
    };
    let status = upstream_resp.status();
    let resp_headers = upstream_resp.headers().clone();
    let body_bytes = match upstream_resp.bytes().await {
        Ok(b) => b,
        Err(e) => return upstream_error(format!("read body: {e}")),
    };
    let mut out = Response::new(Body::from(body_bytes));
    // Convert through the numeric code, since reqwest and axum may use
    // different `StatusCode` types.
    *out.status_mut() = StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
    if let Some(ct) = resp_headers.get(reqwest::header::CONTENT_TYPE) {
        if let Ok(v) = HeaderValue::from_bytes(ct.as_bytes()) {
            out.headers_mut().insert(header::CONTENT_TYPE, v);
        }
    }
    out
}
/// `GET /ws/events`: relays a WebSocket between the browser and the upstream
/// server's `/ws/events` endpoint.
///
/// Requires a valid session cookie (`401` otherwise) and a WebSocket upgrade
/// (`400` otherwise). After the upgrade, text, binary, ping and pong messages
/// are copied in both directions until either side closes or errors.
async fn ws_events(
State(state): State<AppState>,
uri: Uri,
headers: HeaderMap,
ws: Option<WebSocketUpgrade>,
) -> Response {
if !require_session_cookie(&state, &headers).await {
return unauthorized();
}
let Some(ws) = ws else {
return (StatusCode::BAD_REQUEST, "expected websocket upgrade\n").into_response();
};
let query = uri.query().map(|q| format!("?{q}")).unwrap_or_default();
// Map the upstream http(s) base onto the matching ws(s) scheme; any other
// scheme is passed through unchanged.
let ws_url = {
let base = state.upstream_base.as_str();
let ws_base = if let Some(rest) = base.strip_prefix("https://") {
format!("wss://{rest}")
} else if let Some(rest) = base.strip_prefix("http://") {
format!("ws://{rest}")
} else {
base.to_string()
};
format!("{ws_base}/ws/events{query}")
};
let bearer = state.upstream_bearer.as_ref().clone();
// The browser's requested subprotocols are forwarded to upstream verbatim.
// NOTE(review): no subprotocol is selected on the browser-facing upgrade
// here — confirm clients do not rely on one being echoed back.
let subprotocols: Option<String> = headers
.get(axum::http::header::SEC_WEBSOCKET_PROTOCOL)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
ws.on_upgrade(move |client_ws| async move {
let (mut client_tx, mut client_rx) = client_ws.split();
// Any failure to build the request or connect upstream returns from the
// task, which drops the browser-side socket halves.
let mut request =
match tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(
ws_url.as_str(),
) {
Ok(r) => r,
Err(_) => return,
};
if let Some(tok) = bearer {
if let Ok(v) = tokio_tungstenite::tungstenite::http::HeaderValue::from_str(&format!(
"Bearer {tok}"
)) {
request.headers_mut().insert(
tokio_tungstenite::tungstenite::http::header::AUTHORIZATION,
v,
);
}
}
if let Some(proto) = subprotocols.as_ref() {
if let Ok(v) = tokio_tungstenite::tungstenite::http::HeaderValue::from_str(proto) {
request.headers_mut().insert(
tokio_tungstenite::tungstenite::http::header::SEC_WEBSOCKET_PROTOCOL,
v,
);
}
}
let (upstream_ws, _) = match tokio_tungstenite::connect_async(request).await {
Ok(p) => p,
Err(_) => return,
};
let (mut up_tx, mut up_rx) = upstream_ws.split();
// Relay loop: whichever side produces a message first is forwarded to the
// other. Send errors are ignored here; the loop ends when a stream yields
// a close frame, an error, or `None`.
loop {
tokio::select! {
// Browser -> upstream.
msg = client_rx.next() => match msg {
Some(Ok(axum::extract::ws::Message::Text(t))) => {
let _ = up_tx
.send(tokio_tungstenite::tungstenite::Message::Text(t))
.await;
}
Some(Ok(axum::extract::ws::Message::Binary(b))) => {
let _ = up_tx
.send(tokio_tungstenite::tungstenite::Message::Binary(b))
.await;
}
Some(Ok(axum::extract::ws::Message::Ping(p))) => {
let _ = up_tx
.send(tokio_tungstenite::tungstenite::Message::Ping(p))
.await;
}
Some(Ok(axum::extract::ws::Message::Pong(p))) => {
let _ = up_tx
.send(tokio_tungstenite::tungstenite::Message::Pong(p))
.await;
}
// The close frame itself is not forwarded; leaving the loop drops both sockets.
Some(Ok(axum::extract::ws::Message::Close(_))) | None => break,
Some(Err(_)) => break,
},
// Upstream -> browser.
msg = up_rx.next() => match msg {
Some(Ok(tokio_tungstenite::tungstenite::Message::Text(t))) => {
let _ = client_tx
.send(axum::extract::ws::Message::Text(t))
.await;
}
Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(b))) => {
let _ = client_tx
.send(axum::extract::ws::Message::Binary(b))
.await;
}
Some(Ok(tokio_tungstenite::tungstenite::Message::Ping(p))) => {
let _ = client_tx
.send(axum::extract::ws::Message::Ping(p))
.await;
}
Some(Ok(tokio_tungstenite::tungstenite::Message::Pong(p))) => {
let _ = client_tx
.send(axum::extract::ws::Message::Pong(p))
.await;
}
Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) | None => break,
// Raw frames are skipped rather than relayed.
Some(Ok(tokio_tungstenite::tungstenite::Message::Frame(_))) => {}
Some(Err(_)) => break,
},
}
}
})
}
/// Builds the `401 Unauthorized` response used when the session cookie check fails.
fn unauthorized() -> Response {
    let body = "missing session cookie\n";
    (StatusCode::UNAUTHORIZED, body).into_response()
}
/// Builds a `502 Bad Gateway` response whose body is `upstream: <msg>` plus a newline.
fn upstream_error(msg: String) -> Response {
    let mut body = String::from("upstream: ");
    body.push_str(&msg);
    body.push('\n');
    (StatusCode::BAD_GATEWAY, body).into_response()
}
/// Closure constant whose only effect is to reference `HeaderName`.
/// Nothing in the visible code calls it; presumably it exists to keep the
/// `HeaderName` import from being reported as unused — TODO confirm it is
/// still needed.
#[allow(dead_code)]
const _HEADER_NAME_KEEP: fn() = || {
let _: HeaderName = HeaderName::from_static("x-keep");
};
#[cfg(test)]
mod tests {
use super::*;
use axum::body::to_bytes;
use axum::http::Request;
use tower::ServiceExt;
/// Builds an `AppState` for tests: no upstream bearer, the fixed session
/// token `test-token`, and no cookie minted yet.
fn test_state(bundle_dir: PathBuf) -> AppState {
    let upstream_base = Arc::new(String::from("http://127.0.0.1:0"));
    let session_token = Arc::new(String::from("test-token"));
    AppState {
        upstream_base,
        upstream_bearer: Arc::new(None),
        session_token,
        session_cookie: Arc::new(RwLock::new(None)),
        exchange_consumed: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        bundle_dir: Arc::new(bundle_dir),
    }
}
/// Returns a directory containing an `index.html` for the router to serve.
///
/// Uses the crate's real `static` bundle when it has been built; otherwise
/// writes a minimal placeholder page into a per-process temp directory.
fn ensure_bundle_dir() -> PathBuf {
    let checked_in = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("static");
    if checked_in.join("index.html").exists() {
        return checked_in;
    }
    let scratch = std::env::temp_dir().join(format!("cellctl-webui-test-{}", std::process::id()));
    std::fs::create_dir_all(&scratch).expect("mkdir tmp bundle");
    let index = scratch.join("index.html");
    std::fs::write(index, "<!doctype html><title>cellctl webui</title>")
        .expect("write index.html");
    scratch
}
/// `GET /` falls through to the static bundle and returns an HTML document.
#[tokio::test]
async fn serves_index_at_root() {
    let router = build_router(test_state(ensure_bundle_dir()));
    let request = Request::builder()
        .method("GET")
        .uri("/")
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let bytes = to_bytes(response.into_body(), 64 * 1024).await.unwrap();
    let body_str = std::str::from_utf8(&bytes).unwrap();
    let lowered = body_str.to_ascii_lowercase();
    assert!(
        lowered.contains("<!doctype html") || lowered.contains("<html"),
        "expected HTML at /, got: {body_str:.200}"
    );
}
/// A `DELETE` on the proxy path is answered with `405` and `Allow: GET`.
#[tokio::test]
async fn non_get_returns_405() {
    let router = build_router(test_state(ensure_bundle_dir()));
    let request = Request::builder()
        .method("DELETE")
        .uri("/v1/formations")
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
    let allow = response.headers().get(header::ALLOW).map(|v| v.as_bytes());
    assert_eq!(allow, Some(b"GET" as &[u8]),);
}
/// A `PUT` carrying a body on the proxy path is answered with `405`.
#[tokio::test]
async fn put_to_v1_returns_405() {
    let router = build_router(test_state(ensure_bundle_dir()));
    let request = Request::builder()
        .method("PUT")
        .uri("/v1/formations/foo")
        .body(Body::from("body"))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
}
/// The proxy refuses a `GET` that carries no session cookie.
#[tokio::test]
async fn proxy_without_cookie_returns_401() {
    let router = build_router(test_state(ensure_bundle_dir()));
    let request = Request::builder()
        .method("GET")
        .uri("/v1/formations")
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
/// Presenting a token that does not match the session token yields `401`.
#[tokio::test]
async fn auth_exchange_with_wrong_token_returns_401() {
    let router = build_router(test_state(ensure_bundle_dir()));
    let request = Request::builder()
        .method("POST")
        .uri("/auth/exchange")
        .header(header::CONTENT_TYPE, "application/json")
        .body(Body::from(r#"{"sess":"wrong"}"#))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
/// A correct token yields `200`, a hardened `Set-Cookie` header, and the same
/// cookie value stored in the shared state.
#[tokio::test]
async fn auth_exchange_with_right_token_sets_cookie() {
    let state = test_state(ensure_bundle_dir());
    let router = build_router(state.clone());
    let request = Request::builder()
        .method("POST")
        .uri("/auth/exchange")
        .header(header::CONTENT_TYPE, "application/json")
        .body(Body::from(r#"{"sess":"test-token"}"#))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let set_cookie = response
        .headers()
        .get(header::SET_COOKIE)
        .expect("Set-Cookie header present")
        .to_str()
        .unwrap()
        .to_string();
    assert!(set_cookie.starts_with("cellctl_session="));
    assert!(set_cookie.contains("HttpOnly"));
    assert!(set_cookie.contains("SameSite=Strict"));

    // The value handed to the browser must be the one the server remembers.
    let remembered = state.session_cookie.read().await.clone();
    assert!(remembered.is_some());
    let remembered = remembered.unwrap();
    assert!(set_cookie.contains(&format!("cellctl_session={remembered}")));
}
/// After a successful exchange, a `GET` carrying the cookie passes the
/// cookie check, so whatever the upstream outcome, the status is not `401`.
#[tokio::test]
async fn proxy_with_valid_cookie_attempts_upstream() {
    let router = build_router(test_state(ensure_bundle_dir()));

    let exchange_req = Request::builder()
        .method("POST")
        .uri("/auth/exchange")
        .header(header::CONTENT_TYPE, "application/json")
        .body(Body::from(r#"{"sess":"test-token"}"#))
        .unwrap();
    let exchange = router.clone().oneshot(exchange_req).await.unwrap();
    assert_eq!(exchange.status(), StatusCode::OK);

    // Keep only the `name=value` pair; the attributes are not sent back.
    let set_cookie = exchange
        .headers()
        .get(header::SET_COOKIE)
        .unwrap()
        .to_str()
        .unwrap();
    let cookie_pair = set_cookie.split(';').next().unwrap().trim().to_string();

    let proxied_req = Request::builder()
        .method("GET")
        .uri("/v1/formations")
        .header(header::COOKIE, cookie_pair)
        .body(Body::empty())
        .unwrap();
    let proxied = router.oneshot(proxied_req).await.unwrap();
    assert_ne!(proxied.status(), StatusCode::UNAUTHORIZED);
}
/// End-to-end check that only `GET` ever reaches the upstream server.
///
/// A raw TCP mock upstream records the method of every request it receives.
/// Authenticated `POST`/`PUT`/`DELETE`/`PATCH` requests must be answered with
/// `405` + `Allow: GET` and never be seen by the mock, while an authenticated
/// `GET` must reach it.
#[tokio::test]
async fn proxy_only_emits_get_to_upstream() {
use std::sync::Mutex;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
// Mock upstream on an OS-assigned loopback port.
let mock = TcpListener::bind("127.0.0.1:0").await.expect("bind mock");
let mock_addr = mock.local_addr().expect("mock addr");
let methods_seen: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
let methods_for_task = methods_seen.clone();
tokio::spawn(async move {
loop {
let (mut stream, _) = match mock.accept().await {
Ok(p) => p,
Err(_) => return,
};
let methods = methods_for_task.clone();
tokio::spawn(async move {
// A single read is assumed to contain at least the request line.
let mut buf = [0u8; 4096];
let n = match stream.read(&mut buf).await {
Ok(n) => n,
Err(_) => return,
};
let head = String::from_utf8_lossy(&buf[..n]).to_string();
// The first whitespace-separated token of the request line is the method.
if let Some(first_line) = head.lines().next() {
if let Some(method) = first_line.split_whitespace().next() {
methods.lock().unwrap().push(method.to_string());
}
}
// The method is recorded before the canned response is written.
let _ = stream
.write_all(
b"HTTP/1.1 200 OK\r\n\
Content-Type: application/json\r\n\
Content-Length: 2\r\n\
Connection: close\r\n\
\r\n\
{}",
)
.await;
let _ = stream.shutdown().await;
});
}
});
let bundle = ensure_bundle_dir();
let state = AppState {
upstream_base: Arc::new(format!("http://{}", mock_addr)),
upstream_bearer: Arc::new(Some("test-bearer".to_string())),
session_token: Arc::new("test-token".to_string()),
session_cookie: Arc::new(RwLock::new(None)),
exchange_consumed: Arc::new(std::sync::atomic::AtomicBool::new(false)),
bundle_dir: Arc::new(bundle),
};
let app = build_router(state.clone());
// Obtain a session cookie so the requests below are authenticated.
let exch = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/auth/exchange")
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(r#"{"sess":"test-token"}"#))
.unwrap(),
)
.await
.unwrap();
assert_eq!(exch.status(), StatusCode::OK);
let cookie_pair = exch
.headers()
.get(header::SET_COOKIE)
.unwrap()
.to_str()
.unwrap()
.split(';')
.next()
.unwrap()
.trim()
.to_string();
// Every mutating method must be stopped by the middleware.
let non_get_methods = ["POST", "PUT", "DELETE", "PATCH"];
for m in non_get_methods {
let resp = app
.clone()
.oneshot(
Request::builder()
.method(m)
.uri("/v1/formations")
.header(header::COOKIE, cookie_pair.clone())
.body(Body::from("payload that must never reach upstream"))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
resp.status(),
StatusCode::METHOD_NOT_ALLOWED,
"inbound {m} should be 405"
);
assert_eq!(
resp.headers().get(header::ALLOW).map(|v| v.as_bytes()),
Some(b"GET" as &[u8]),
"405 response for {m} must carry `Allow: GET`"
);
}
// Control case: a GET must go through to the mock.
let get_resp = app
.clone()
.oneshot(
Request::builder()
.method("GET")
.uri("/v1/formations")
.header(header::COOKIE, cookie_pair.clone())
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_ne!(
get_resp.status(),
StatusCode::METHOD_NOT_ALLOWED,
"GET must pass the 405 middleware"
);
// Short grace period in case any stray request is still in flight to the mock.
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let observed = methods_seen.lock().unwrap().clone();
let non_get_count = observed.iter().filter(|m| m.as_str() != "GET").count();
assert_eq!(
non_get_count, 0,
"upstream saw non-GET method(s): {observed:?}"
);
assert!(
observed.iter().any(|m| m == "GET"),
"expected at least one GET to reach upstream, saw {observed:?}"
);
}
}