1use std::net::SocketAddr;
35use std::path::PathBuf;
36use std::sync::Arc;
37
38use axum::body::{Body, Bytes};
39use axum::extract::{ws::WebSocketUpgrade, Path as AxumPath, State};
40use axum::http::{header, HeaderMap, HeaderName, HeaderValue, Method, StatusCode, Uri};
41use axum::response::{IntoResponse, Response};
42use axum::routing::{any, post};
43use axum::Router;
44use base64::Engine as _;
45use clap::ValueEnum;
46use futures_util::{SinkExt, StreamExt};
47use rand::RngCore;
48use serde::Deserialize;
49use tokio::sync::RwLock;
50use tower_http::services::ServeDir;
51
52use crate::client::CellosClient;
53use crate::config::Config;
54use crate::exit::{CtlError, CtlResult};
55
/// Directory name, relative to this crate's `CARGO_MANIFEST_DIR`, that
/// `resolve_bundle_dir` falls back to when no environment override is set.
const BUNDLE_DIR_RELATIVE: &str = "static";
61
/// Which listeners `run` should bind; `resolve_bind_plan` maps each variant
/// to a `(tcp, unix)` pair.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, ValueEnum)]
pub enum BindMode {
    /// Default. Loopback TCP, plus a Unix socket when built for a Unix target.
    #[default]
    Auto,
    /// Loopback TCP only.
    Loopback,
    /// Unix domain socket only; rejected with a usage error on non-Unix targets.
    Unix,
}
76
/// Shared state handed to every request handler. All fields are `Arc`s, so
/// cloning the struct only bumps reference counts.
#[derive(Clone)]
struct AppState {
    /// Upstream server base URL; `run` strips trailing `/` before storing it.
    upstream_base: Arc<String>,
    /// Bearer token attached to upstream HTTP and WebSocket requests when present.
    upstream_bearer: Arc<Option<String>>,
    /// One-time token that `/auth/exchange` compares the posted `sess` against.
    session_token: Arc<String>,
    /// Cookie value minted by a successful exchange; `None` until then.
    session_cookie: Arc<RwLock<Option<String>>>,
    /// Set to `true` by the first successful exchange so later attempts are refused.
    exchange_consumed: Arc<std::sync::atomic::AtomicBool>,
    /// Directory the static-file fallback service serves from.
    bundle_dir: Arc<PathBuf>,
}
105
/// JSON body accepted by `POST /auth/exchange`.
#[derive(Deserialize)]
struct ExchangeRequest {
    /// Token supplied by the caller; compared against `AppState::session_token`.
    sess: String,
}
110
/// Starts the local web UI server and blocks until it shuts down.
///
/// Binds the listeners selected by `bind`, prints the resulting URL(s),
/// optionally launches a browser when `open` is set, then serves until a
/// shutdown signal arrives (see `wait_for_shutdown_signal`).
///
/// # Errors
///
/// Propagates errors from `CellosClient::new`, `resolve_bundle_dir` and
/// `resolve_bind_plan`; returns a usage error when `index.html` is missing
/// from the bundle directory or a listener cannot be bound; returns an API
/// error (`webui server: …`) when a serve task fails or cannot be joined.
pub async fn run(cfg: &Config, open: bool, bind: BindMode) -> CtlResult<()> {
    // The client is discarded immediately; only a construction failure matters.
    // NOTE(review): presumably this surfaces configuration errors early — confirm.
    let _ = CellosClient::new(cfg)?;

    let upstream_base = cfg.effective_server().trim_end_matches('/').to_string();
    let upstream_bearer = cfg.effective_token();

    // Refuse to start without a built bundle rather than serving 404s.
    let bundle_dir = resolve_bundle_dir()?;
    if !bundle_dir.join("index.html").exists() {
        return Err(CtlError::usage(format!(
            "webui bundle not found at {}/index.html — run `npm --prefix web run build` first",
            bundle_dir.display()
        )));
    }

    let session_token = mint_session_token();

    // Keep a copy for the startup banner; the original moves into the state.
    let upstream_log = upstream_base.clone();
    let state = AppState {
        upstream_base: Arc::new(upstream_base),
        upstream_bearer: Arc::new(upstream_bearer),
        session_token: Arc::new(session_token.clone()),
        session_cookie: Arc::new(RwLock::new(None)),
        exchange_consumed: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        bundle_dir: Arc::new(bundle_dir.clone()),
    };

    let app = build_router(state);

    let (want_tcp, want_unix) = resolve_bind_plan(bind)?;

    // Port 0 lets the OS pick a free loopback port.
    let tcp_listener = if want_tcp {
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        Some(
            tokio::net::TcpListener::bind(addr)
                .await
                .map_err(|e| CtlError::usage(format!("bind 127.0.0.1: {e}")))?,
        )
    } else {
        None
    };

    // Only the path is computed here; the socket itself is bound further down.
    #[cfg(unix)]
    let unix_socket_path: Option<PathBuf> = if want_unix {
        Some(unix_socket_path_for_pid(std::process::id()))
    } else {
        None
    };
    #[cfg(not(unix))]
    let unix_socket_path: Option<PathBuf> = None;

    // The session token travels in the URL fragment of the printed/opened URL.
    let browser_url = if let Some(l) = tcp_listener.as_ref() {
        let local_addr = l
            .local_addr()
            .map_err(|e| CtlError::usage(format!("local_addr: {e}")))?;
        Some(format!("http://{}/#sess={}", local_addr, session_token))
    } else {
        None
    };

    if let Some(url) = browser_url.as_ref() {
        println!("cellctl webui: {}", url);
    }
    if let Some(p) = unix_socket_path.as_ref() {
        println!("cellctl webui: unix://{}", p.display());
    }
    // No TCP listener means no URL a browser can reach directly; explain the options.
    if browser_url.is_none() {
        eprintln!(
            "cellctl webui: --bind unix has no browser-reachable URL; \
             use a forwarder (e.g. `socat TCP-LISTEN:0,reuseaddr,fork UNIX-CONNECT:{}`) \
             or rerun with `--bind auto` for a loopback URL.",
            unix_socket_path
                .as_ref()
                .map(|p| p.display().to_string())
                .unwrap_or_else(|| "<socket>".to_string())
        );
    }
    println!("upstream: {}", upstream_log);
    println!("press Ctrl-C to stop");

    // Browser launch is best-effort: failures are reported but never abort the server.
    if open {
        if let Some(url) = browser_url.as_ref() {
            if !is_safe_open_url(url) {
                eprintln!(
                    "cellctl webui: refusing to open URL (failed loopback-http sanity check)"
                );
            } else if let Err(e) = opener::open(url) {
                eprintln!("cellctl webui: could not launch browser: {e}");
            }
        } else {
            eprintln!("cellctl webui: --open ignored: no loopback URL bound (use --bind auto)");
        }
    }

    // One broadcast channel fans the shutdown signal out to every listener task.
    // `_shutdown_rx` stays bound so the channel always has at least one receiver.
    let (shutdown_tx, _shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
    {
        let shutdown_tx = shutdown_tx.clone();
        tokio::spawn(async move {
            wait_for_shutdown_signal().await;
            eprintln!("shutting down");
            let _ = shutdown_tx.send(());
        });
    }

    let tcp_task = if let Some(listener) = tcp_listener {
        let app = app.clone();
        let mut rx = shutdown_tx.subscribe();
        Some(tokio::spawn(async move {
            axum::serve(listener, app)
                .with_graceful_shutdown(async move {
                    let _ = rx.recv().await;
                })
                .await
        }))
    } else {
        None
    };

    // The Unix socket is bound here, after the banner has already been printed.
    #[cfg(unix)]
    let unix_task = if let Some(path) = unix_socket_path.clone() {
        let app = app.clone();
        let mut rx = shutdown_tx.subscribe();
        let listener = bind_unix_listener(&path)?;
        Some(tokio::spawn(async move {
            serve_unix(listener, app, async move {
                let _ = rx.recv().await;
            })
            .await
        }))
    } else {
        None
    };
    #[cfg(not(unix))]
    let unix_task: Option<tokio::task::JoinHandle<std::io::Result<()>>> = None;

    // Wait for both tasks; remember only the first failure encountered.
    let mut first_err: Option<String> = None;
    if let Some(t) = tcp_task {
        match t.await {
            Ok(Ok(())) => {}
            Ok(Err(e)) => {
                let _ = first_err.get_or_insert_with(|| format!("tcp: {e}"));
            }
            Err(e) => {
                let _ = first_err.get_or_insert_with(|| format!("tcp join: {e}"));
            }
        }
    }
    if let Some(t) = unix_task {
        match t.await {
            Ok(Ok(())) => {}
            Ok(Err(e)) => {
                let _ = first_err.get_or_insert_with(|| format!("unix: {e}"));
            }
            Err(e) => {
                let _ = first_err.get_or_insert_with(|| format!("unix join: {e}"));
            }
        }
    }

    // Best-effort removal of the socket file; a failure here is ignored.
    if let Some(p) = unix_socket_path.as_ref() {
        let _ = std::fs::remove_file(p);
    }

    if let Some(e) = first_err {
        return Err(CtlError::api(format!("webui server: {e}")));
    }
    Ok(())
}
306
307fn resolve_bind_plan(bind: BindMode) -> CtlResult<(bool, bool)> {
312 #[cfg(unix)]
313 {
314 Ok(match bind {
315 BindMode::Auto => (true, true),
316 BindMode::Loopback => (true, false),
317 BindMode::Unix => (false, true),
318 })
319 }
320 #[cfg(not(unix))]
321 {
322 Ok(match bind {
323 BindMode::Auto => (true, false),
324 BindMode::Loopback => (true, false),
325 BindMode::Unix => {
326 return Err(CtlError::usage(
327 "--bind unix is not supported on Windows; use --bind loopback (the default)",
328 ));
329 }
330 })
331 }
332}
333
334#[cfg(unix)]
344fn unix_socket_path_for_pid(pid: u32) -> PathBuf {
345 let dir = if let Some(xdg) = std::env::var_os("XDG_RUNTIME_DIR") {
346 PathBuf::from(xdg)
347 } else {
348 let uid = unsafe { libc::getuid() };
351 PathBuf::from("/tmp").join(format!("cellctl-{uid}"))
352 };
353 dir.join(format!("cellctl-webui-{pid}.sock"))
354}
355
/// Binds a Unix domain socket at `path` with owner-only permissions.
///
/// Creates the parent directory if it is missing, removes any existing file
/// at `path`, binds under a restrictive umask, then sets the socket file's
/// mode to `0600`.
///
/// # Errors
///
/// Returns a usage error when the parent directory cannot be created, the
/// bind fails, or the final `chmod` fails.
#[cfg(unix)]
fn bind_unix_listener(path: &std::path::Path) -> CtlResult<tokio::net::UnixListener> {
    use std::os::unix::fs::PermissionsExt;

    if let Some(parent) = path.parent() {
        if !parent.exists() {
            std::fs::create_dir_all(parent)
                .map_err(|e| CtlError::usage(format!("mkdir {}: {e}", parent.display())))?;
        }
        // Best-effort tightening of the directory to 0700; a failure is ignored.
        let _ = std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o700));
    }

    // Clear a leftover socket file so the bind below does not fail on it.
    let _ = std::fs::remove_file(path);

    // Bind with umask 077 so the socket file is never created group/world
    // accessible, then restore the previous mask.
    // NOTE(review): umask is process-wide, so files created concurrently by
    // other threads during this window also get the restrictive mask.
    let prev_umask = unsafe { libc::umask(0o077) };
    let bind_result = tokio::net::UnixListener::bind(path);
    unsafe {
        let _ = libc::umask(prev_umask);
    }
    let listener =
        bind_result.map_err(|e| CtlError::usage(format!("bind {}: {e}", path.display())))?;

    // Set the mode explicitly as well, independent of the umask above.
    let perms = std::fs::Permissions::from_mode(0o600);
    std::fs::set_permissions(path, perms)
        .map_err(|e| CtlError::usage(format!("chmod 0600 {}: {e}", path.display())))?;

    Ok(listener)
}
416
417async fn wait_for_shutdown_signal() {
420 #[cfg(unix)]
421 {
422 use tokio::signal::unix::{signal, SignalKind};
423 let mut sigint = match signal(SignalKind::interrupt()) {
424 Ok(s) => s,
425 Err(_) => {
426 let _ = tokio::signal::ctrl_c().await;
427 return;
428 }
429 };
430 let mut sigterm = match signal(SignalKind::terminate()) {
431 Ok(s) => s,
432 Err(_) => {
433 let _ = sigint.recv().await;
434 return;
435 }
436 };
437 tokio::select! {
438 _ = sigint.recv() => {},
439 _ = sigterm.recv() => {},
440 }
441 }
442 #[cfg(not(unix))]
443 {
444 let _ = tokio::signal::ctrl_c().await;
445 }
446}
447
448fn is_safe_open_url(s: &str) -> bool {
458 if s.chars().any(|c| c.is_control()) {
461 return false;
462 }
463 let Ok(u) = url::Url::parse(s) else {
464 return false;
465 };
466 let scheme = u.scheme();
467 if scheme != "http" && scheme != "https" {
468 return false;
469 }
470 let Some(host) = u.host_str() else {
471 return false;
472 };
473 matches!(host, "127.0.0.1" | "localhost" | "::1" | "[::1]")
476}
477
/// Serves `app` on a Unix domain socket until `shutdown` resolves.
///
/// Each accepted connection is driven on its own spawned task. When
/// `shutdown` completes the accept loop returns `Ok(())`; connection tasks
/// already spawned are not awaited here.
#[cfg(unix)]
async fn serve_unix(
    listener: tokio::net::UnixListener,
    app: Router,
    shutdown: impl std::future::Future<Output = ()>,
) -> std::io::Result<()> {
    use std::convert::Infallible;
    use tower::Service;

    // Pin so the same future can be polled again on every loop iteration.
    tokio::pin!(shutdown);

    loop {
        tokio::select! {
            _ = &mut shutdown => return Ok(()),
            accepted = listener.accept() => {
                let (stream, _peer) = match accepted {
                    Ok(s) => s,
                    Err(e) => {
                        // An accept failure is logged and the loop keeps serving.
                        eprintln!("webui: unix accept error: {e}");
                        continue;
                    }
                };
                let app = app.clone();
                tokio::spawn(async move {
                    // Adapt the router to hyper's service interface: clone it per
                    // request and convert the incoming body into an axum `Body`.
                    let svc = hyper::service::service_fn(move |req: http::Request<hyper::body::Incoming>| {
                        let mut router = app.clone();
                        async move {
                            let resp: Response = match router.call(req.map(Body::new)).await {
                                Ok(r) => r,
                                // The error type is uninhabited, so this arm cannot run.
                                Err(never) => match never {},
                            };
                            Ok::<_, Infallible>(resp)
                        }
                    });
                    let io = hyper_util::rt::TokioIo::new(stream);
                    // Upgrade support is needed for the WebSocket route; per-connection
                    // errors are ignored.
                    let _ = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
                        .serve_connection_with_upgrades(io, svc)
                        .await;
                });
            }
        }
    }
}
534
535fn mint_session_token() -> String {
537 let mut buf = [0u8; 32];
538 rand::thread_rng().fill_bytes(&mut buf);
539 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(buf)
540}
541
542fn resolve_bundle_dir() -> CtlResult<PathBuf> {
546 if let Ok(p) = std::env::var("CELLCTL_WEBUI_BUNDLE_DIR") {
547 return Ok(PathBuf::from(p));
548 }
549 let from_manifest = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(BUNDLE_DIR_RELATIVE);
552 Ok(from_manifest)
553}
554
555fn build_router(state: AppState) -> Router {
558 let bundle_dir = state.bundle_dir.as_ref().clone();
561 let serve_dir = ServeDir::new(&bundle_dir).append_index_html_on_directories(true);
562
563 Router::new()
564 .route("/auth/exchange", post(auth_exchange))
565 .route("/v1/*rest", any(proxy_v1))
566 .route("/ws/events", any(ws_events))
567 .fallback_service(serve_dir)
571 .layer(axum::middleware::from_fn(reject_non_get))
572 .with_state(state)
573}
574
575async fn reject_non_get(req: axum::http::Request<Body>, next: axum::middleware::Next) -> Response {
579 let method = req.method().clone();
580 let path = req.uri().path().to_string();
581
582 let is_auth_exchange = method == Method::POST && path == "/auth/exchange";
586 if method != Method::GET && !is_auth_exchange {
587 return method_not_allowed();
588 }
589 next.run(req).await
590}
591
592fn method_not_allowed() -> Response {
593 let mut resp = (StatusCode::METHOD_NOT_ALLOWED, "method not allowed\n").into_response();
594 resp.headers_mut()
595 .insert(header::ALLOW, HeaderValue::from_static("GET"));
596 resp
597}
598
/// `Max-Age` (in seconds) written into the session cookie by `auth_exchange`;
/// 86 400 s is 24 hours.
const COOKIE_MAX_AGE_SECS: u64 = 86_400;
603
604async fn auth_exchange(State(state): State<AppState>, headers: HeaderMap, body: Bytes) -> Response {
624 use std::sync::atomic::Ordering;
625 use subtle::ConstantTimeEq;
626
627 let ct_ok = headers
631 .get(header::CONTENT_TYPE)
632 .and_then(|v| v.to_str().ok())
633 .map(|s| {
634 let trimmed = s.split(';').next().unwrap_or("").trim();
636 trimmed.eq_ignore_ascii_case("application/json")
637 })
638 .unwrap_or(false);
639 if !ct_ok {
640 let mut resp = (
641 StatusCode::UNSUPPORTED_MEDIA_TYPE,
642 "Content-Type must be application/json\n",
643 )
644 .into_response();
645 resp.headers_mut().insert(
646 header::CONTENT_TYPE,
647 HeaderValue::from_static("text/plain; charset=utf-8"),
648 );
649 return resp;
650 }
651
652 let parsed: ExchangeRequest = match serde_json::from_slice(&body) {
655 Ok(p) => p,
656 Err(_) => return (StatusCode::BAD_REQUEST, "invalid json body\n").into_response(),
657 };
658
659 if state.exchange_consumed.load(Ordering::SeqCst) {
662 return (
663 StatusCode::UNAUTHORIZED,
664 "session token already exchanged\n",
665 )
666 .into_response();
667 }
668
669 let provided = parsed.sess.as_bytes();
674 let expected = state.session_token.as_bytes();
675 if provided.len() != expected.len() || provided.ct_eq(expected).unwrap_u8() == 0 {
676 return (StatusCode::UNAUTHORIZED, "bad session token\n").into_response();
677 }
678
679 if state
684 .exchange_consumed
685 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
686 .is_err()
687 {
688 return (
689 StatusCode::UNAUTHORIZED,
690 "session token already exchanged\n",
691 )
692 .into_response();
693 }
694
695 let mut buf = [0u8; 32];
698 rand::thread_rng().fill_bytes(&mut buf);
699 let cookie_value = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(buf);
700
701 {
702 let mut slot = state.session_cookie.write().await;
703 *slot = Some(cookie_value.clone());
704 }
705
706 let cookie_header = format!(
709 "cellctl_session={}; HttpOnly; SameSite=Strict; Path=/; Max-Age={}",
710 cookie_value, COOKIE_MAX_AGE_SECS
711 );
712
713 let mut resp = (StatusCode::OK, "ok\n").into_response();
714 resp.headers_mut().insert(
715 header::SET_COOKIE,
716 HeaderValue::from_str(&cookie_header).unwrap(),
717 );
718 resp
719}
720
721async fn require_session_cookie(state: &AppState, headers: &HeaderMap) -> bool {
724 let expected = match state.session_cookie.read().await.clone() {
725 Some(v) => v,
726 None => return false,
727 };
728 let Some(cookie_hdr) = headers.get(header::COOKIE) else {
729 return false;
730 };
731 let Ok(cookie_str) = cookie_hdr.to_str() else {
732 return false;
733 };
734 for entry in cookie_str.split(';') {
736 let entry = entry.trim();
737 if let Some(v) = entry.strip_prefix("cellctl_session=") {
738 return v == expected;
739 }
740 }
741 false
742}
743
744async fn proxy_v1(
746 State(state): State<AppState>,
747 AxumPath(rest): AxumPath<String>,
748 uri: Uri,
749 headers: HeaderMap,
750) -> Response {
751 if !require_session_cookie(&state, &headers).await {
752 return unauthorized();
753 }
754
755 let query = uri.query().map(|q| format!("?{q}")).unwrap_or_default();
756 let upstream_url = format!("{}/v1/{}{}", state.upstream_base, rest, query);
757
758 let client = match reqwest::Client::builder().build() {
759 Ok(c) => c,
760 Err(e) => return upstream_error(format!("client: {e}")),
761 };
762
763 let mut req = client.get(&upstream_url);
764 if let Some(tok) = state.upstream_bearer.as_ref() {
765 req = req.header(reqwest::header::AUTHORIZATION, format!("Bearer {tok}"));
766 }
767
768 let upstream_resp = match req.send().await {
769 Ok(r) => r,
770 Err(e) => return upstream_error(format!("send: {e}")),
771 };
772
773 let status = upstream_resp.status();
774 let resp_headers = upstream_resp.headers().clone();
775 let body_bytes = match upstream_resp.bytes().await {
776 Ok(b) => b,
777 Err(e) => return upstream_error(format!("read body: {e}")),
778 };
779
780 let mut out = Response::new(Body::from(body_bytes));
781 *out.status_mut() = StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
782 if let Some(ct) = resp_headers.get(reqwest::header::CONTENT_TYPE) {
784 if let Ok(v) = HeaderValue::from_bytes(ct.as_bytes()) {
785 out.headers_mut().insert(header::CONTENT_TYPE, v);
786 }
787 }
788 out
789}
790
/// `/ws/events`: relays a WebSocket between the browser and the upstream server.
///
/// Requires a valid session cookie (401 otherwise) and a WebSocket upgrade
/// request (400 otherwise). After the upgrade it connects to the upstream
/// `/ws/events` endpoint, passing along the query string, the bearer token
/// (if configured) and the client's `Sec-WebSocket-Protocol` header, then
/// copies text, binary, ping and pong messages in both directions until
/// either side closes or errors.
async fn ws_events(
    State(state): State<AppState>,
    uri: Uri,
    headers: HeaderMap,
    ws: Option<WebSocketUpgrade>,
) -> Response {
    if !require_session_cookie(&state, &headers).await {
        return unauthorized();
    }
    let Some(ws) = ws else {
        return (StatusCode::BAD_REQUEST, "expected websocket upgrade\n").into_response();
    };

    let query = uri.query().map(|q| format!("?{q}")).unwrap_or_default();
    let ws_url = {
        let base = state.upstream_base.as_str();
        // Map the HTTP scheme to its WebSocket counterpart; any other scheme is kept.
        let ws_base = if let Some(rest) = base.strip_prefix("https://") {
            format!("wss://{rest}")
        } else if let Some(rest) = base.strip_prefix("http://") {
            format!("ws://{rest}")
        } else {
            base.to_string()
        };
        format!("{ws_base}/ws/events{query}")
    };
    // Owned copies, because the upgrade callback below must be `move`.
    let bearer = state.upstream_bearer.as_ref().clone();
    let subprotocols: Option<String> = headers
        .get(axum::http::header::SEC_WEBSOCKET_PROTOCOL)
        .and_then(|v| v.to_str().ok())
        .map(|s| s.to_string());

    ws.on_upgrade(move |client_ws| async move {
        let (mut client_tx, mut client_rx) = client_ws.split();

        // Any failure while building or connecting the upstream request ends
        // the task, which drops the client socket.
        let mut request =
            match tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(
                ws_url.as_str(),
            ) {
                Ok(r) => r,
                Err(_) => return,
            };
        if let Some(tok) = bearer {
            if let Ok(v) = tokio_tungstenite::tungstenite::http::HeaderValue::from_str(&format!(
                "Bearer {tok}"
            )) {
                request.headers_mut().insert(
                    tokio_tungstenite::tungstenite::http::header::AUTHORIZATION,
                    v,
                );
            }
        }
        if let Some(proto) = subprotocols.as_ref() {
            if let Ok(v) = tokio_tungstenite::tungstenite::http::HeaderValue::from_str(proto) {
                request.headers_mut().insert(
                    tokio_tungstenite::tungstenite::http::header::SEC_WEBSOCKET_PROTOCOL,
                    v,
                );
            }
        }
        let (upstream_ws, _) = match tokio_tungstenite::connect_async(request).await {
            Ok(p) => p,
            Err(_) => return,
        };
        let (mut up_tx, mut up_rx) = upstream_ws.split();

        // Relay loop. Send errors are ignored; the loop ends when either side
        // yields a close frame, an error, or end-of-stream. Close frames
        // themselves are not forwarded.
        loop {
            tokio::select! {
                // Browser -> upstream.
                msg = client_rx.next() => match msg {
                    Some(Ok(axum::extract::ws::Message::Text(t))) => {
                        let _ = up_tx
                            .send(tokio_tungstenite::tungstenite::Message::Text(t))
                            .await;
                    }
                    Some(Ok(axum::extract::ws::Message::Binary(b))) => {
                        let _ = up_tx
                            .send(tokio_tungstenite::tungstenite::Message::Binary(b))
                            .await;
                    }
                    Some(Ok(axum::extract::ws::Message::Ping(p))) => {
                        let _ = up_tx
                            .send(tokio_tungstenite::tungstenite::Message::Ping(p))
                            .await;
                    }
                    Some(Ok(axum::extract::ws::Message::Pong(p))) => {
                        let _ = up_tx
                            .send(tokio_tungstenite::tungstenite::Message::Pong(p))
                            .await;
                    }
                    Some(Ok(axum::extract::ws::Message::Close(_))) | None => break,
                    Some(Err(_)) => break,
                },
                // Upstream -> browser.
                msg = up_rx.next() => match msg {
                    Some(Ok(tokio_tungstenite::tungstenite::Message::Text(t))) => {
                        let _ = client_tx
                            .send(axum::extract::ws::Message::Text(t))
                            .await;
                    }
                    Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(b))) => {
                        let _ = client_tx
                            .send(axum::extract::ws::Message::Binary(b))
                            .await;
                    }
                    Some(Ok(tokio_tungstenite::tungstenite::Message::Ping(p))) => {
                        let _ = client_tx
                            .send(axum::extract::ws::Message::Ping(p))
                            .await;
                    }
                    Some(Ok(tokio_tungstenite::tungstenite::Message::Pong(p))) => {
                        let _ = client_tx
                            .send(axum::extract::ws::Message::Pong(p))
                            .await;
                    }
                    Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) | None => break,
                    // Raw frames are skipped rather than relayed.
                    Some(Ok(tokio_tungstenite::tungstenite::Message::Frame(_))) => {}
                    Some(Err(_)) => break,
                },
            }
        }
    })
}
927
928fn unauthorized() -> Response {
929 (StatusCode::UNAUTHORIZED, "missing session cookie\n").into_response()
930}
931
932fn upstream_error(msg: String) -> Response {
933 (StatusCode::BAD_GATEWAY, format!("upstream: {msg}\n")).into_response()
934}
935
// Never-called closure constant that references `HeaderName`.
// NOTE(review): presumably this exists only so the `HeaderName` import does
// not trigger an unused-import warning — confirm; if so, dropping the import
// would be simpler.
#[allow(dead_code)]
const _HEADER_NAME_KEEP: fn() = || {
    let _: HeaderName = HeaderName::from_static("x-keep");
};
947
// Router-level tests. Each test builds a fresh router and drives it in
// process with `oneshot`; no listener is bound except the mock upstream in
// `proxy_only_emits_get_to_upstream`.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::to_bytes;
    use axum::http::Request;
    use tower::ServiceExt;

    /// State with a fixed session token (`test-token`), no bearer, and an
    /// upstream address (port 0) that is not expected to accept connections.
    fn test_state(bundle_dir: PathBuf) -> AppState {
        AppState {
            upstream_base: Arc::new("http://127.0.0.1:0".to_string()),
            upstream_bearer: Arc::new(None),
            session_token: Arc::new("test-token".to_string()),
            session_cookie: Arc::new(RwLock::new(None)),
            exchange_consumed: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            bundle_dir: Arc::new(bundle_dir),
        }
    }

    /// Returns the real `static` bundle when it contains `index.html`,
    /// otherwise a per-process temp directory holding a stub `index.html`.
    fn ensure_bundle_dir() -> PathBuf {
        let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let real = manifest_dir.join("static");
        if real.join("index.html").exists() {
            return real;
        }
        let tmp = std::env::temp_dir().join(format!("cellctl-webui-test-{}", std::process::id()));
        std::fs::create_dir_all(&tmp).expect("mkdir tmp bundle");
        std::fs::write(
            tmp.join("index.html"),
            "<!doctype html><title>cellctl webui</title>",
        )
        .expect("write index.html");
        tmp
    }

    /// `GET /` falls through to the static service and returns HTML.
    #[tokio::test]
    async fn serves_index_at_root() {
        let bundle = ensure_bundle_dir();
        let app = build_router(test_state(bundle));

        let resp = app
            .oneshot(
                Request::builder()
                    .method("GET")
                    .uri("/")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = to_bytes(resp.into_body(), 64 * 1024).await.unwrap();
        let body_str = std::str::from_utf8(&body).unwrap();
        assert!(
            body_str.to_ascii_lowercase().contains("<!doctype html")
                || body_str.to_ascii_lowercase().contains("<html"),
            "expected HTML at /, got: {body_str:.200}"
        );
    }

    /// A `DELETE` is rejected by the middleware with 405 and `Allow: GET`.
    #[tokio::test]
    async fn non_get_returns_405() {
        let bundle = ensure_bundle_dir();
        let app = build_router(test_state(bundle));

        let resp = app
            .oneshot(
                Request::builder()
                    .method("DELETE")
                    .uri("/v1/formations")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(
            resp.headers().get(header::ALLOW).map(|v| v.as_bytes()),
            Some(b"GET" as &[u8]),
        );
    }

    /// A `PUT` with a body is rejected with 405 as well.
    #[tokio::test]
    async fn put_to_v1_returns_405() {
        let bundle = ensure_bundle_dir();
        let app = build_router(test_state(bundle));

        let resp = app
            .oneshot(
                Request::builder()
                    .method("PUT")
                    .uri("/v1/formations/foo")
                    .body(Body::from("body"))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
    }

    /// The proxy refuses requests that carry no session cookie.
    #[tokio::test]
    async fn proxy_without_cookie_returns_401() {
        let bundle = ensure_bundle_dir();
        let app = build_router(test_state(bundle));

        let resp = app
            .oneshot(
                Request::builder()
                    .method("GET")
                    .uri("/v1/formations")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    /// A wrong `sess` value is answered with 401.
    #[tokio::test]
    async fn auth_exchange_with_wrong_token_returns_401() {
        let bundle = ensure_bundle_dir();
        let app = build_router(test_state(bundle));

        let resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/auth/exchange")
                    .header(header::CONTENT_TYPE, "application/json")
                    .body(Body::from(r#"{"sess":"wrong"}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    /// The correct token yields 200, a hardened `Set-Cookie`, and the same
    /// cookie value stored in the shared state.
    #[tokio::test]
    async fn auth_exchange_with_right_token_sets_cookie() {
        let bundle = ensure_bundle_dir();
        let state = test_state(bundle);
        let app = build_router(state.clone());

        let resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/auth/exchange")
                    .header(header::CONTENT_TYPE, "application/json")
                    .body(Body::from(r#"{"sess":"test-token"}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let cookie = resp
            .headers()
            .get(header::SET_COOKIE)
            .expect("Set-Cookie header present")
            .to_str()
            .unwrap()
            .to_string();
        assert!(cookie.starts_with("cellctl_session="));
        assert!(cookie.contains("HttpOnly"));
        assert!(cookie.contains("SameSite=Strict"));

        let stored = state.session_cookie.read().await.clone();
        assert!(stored.is_some());
        let stored = stored.unwrap();
        assert!(cookie.contains(&format!("cellctl_session={stored}")));
    }

    /// With a valid cookie the proxy gets past the auth check. Any status
    /// other than 401 is accepted, since the configured upstream is not
    /// expected to answer.
    #[tokio::test]
    async fn proxy_with_valid_cookie_attempts_upstream() {
        let bundle = ensure_bundle_dir();
        let state = test_state(bundle);
        let app = build_router(state.clone());

        let exch = app
            .clone()
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/auth/exchange")
                    .header(header::CONTENT_TYPE, "application/json")
                    .body(Body::from(r#"{"sess":"test-token"}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(exch.status(), StatusCode::OK);
        let cookie_hdr = exch
            .headers()
            .get(header::SET_COOKIE)
            .unwrap()
            .to_str()
            .unwrap()
            .to_string();
        // Keep only the `name=value` pair; attributes are not sent back by clients.
        let cookie_pair = cookie_hdr.split(';').next().unwrap().trim().to_string();

        let resp = app
            .oneshot(
                Request::builder()
                    .method("GET")
                    .uri("/v1/formations")
                    .header(header::COOKIE, cookie_pair)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_ne!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    /// End-to-end check against a mock upstream that records the method of
    /// every request it receives: authenticated non-GET requests must get 405
    /// and never reach upstream, while a GET must.
    #[tokio::test]
    async fn proxy_only_emits_get_to_upstream() {
        use std::sync::Mutex;
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        let mock = TcpListener::bind("127.0.0.1:0").await.expect("bind mock");
        let mock_addr = mock.local_addr().expect("mock addr");
        let methods_seen: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let methods_for_task = methods_seen.clone();

        // Minimal HTTP/1.1 responder: record the request-line method, reply
        // with a fixed `{}` JSON body, close the connection.
        tokio::spawn(async move {
            loop {
                let (mut stream, _) = match mock.accept().await {
                    Ok(p) => p,
                    Err(_) => return,
                };
                let methods = methods_for_task.clone();
                tokio::spawn(async move {
                    let mut buf = [0u8; 4096];
                    let n = match stream.read(&mut buf).await {
                        Ok(n) => n,
                        Err(_) => return,
                    };
                    let head = String::from_utf8_lossy(&buf[..n]).to_string();
                    if let Some(first_line) = head.lines().next() {
                        if let Some(method) = first_line.split_whitespace().next() {
                            methods.lock().unwrap().push(method.to_string());
                        }
                    }
                    let _ = stream
                        .write_all(
                            b"HTTP/1.1 200 OK\r\n\
                              Content-Type: application/json\r\n\
                              Content-Length: 2\r\n\
                              Connection: close\r\n\
                              \r\n\
                              {}",
                        )
                        .await;
                    let _ = stream.shutdown().await;
                });
            }
        });

        let bundle = ensure_bundle_dir();
        let state = AppState {
            upstream_base: Arc::new(format!("http://{}", mock_addr)),
            upstream_bearer: Arc::new(Some("test-bearer".to_string())),
            session_token: Arc::new("test-token".to_string()),
            session_cookie: Arc::new(RwLock::new(None)),
            exchange_consumed: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            bundle_dir: Arc::new(bundle),
        };
        let app = build_router(state.clone());

        // Authenticate first so the later requests are blocked by the method
        // gate and not by the cookie check.
        let exch = app
            .clone()
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/auth/exchange")
                    .header(header::CONTENT_TYPE, "application/json")
                    .body(Body::from(r#"{"sess":"test-token"}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(exch.status(), StatusCode::OK);
        let cookie_pair = exch
            .headers()
            .get(header::SET_COOKIE)
            .unwrap()
            .to_str()
            .unwrap()
            .split(';')
            .next()
            .unwrap()
            .trim()
            .to_string();

        let non_get_methods = ["POST", "PUT", "DELETE", "PATCH"];
        for m in non_get_methods {
            let resp = app
                .clone()
                .oneshot(
                    Request::builder()
                        .method(m)
                        .uri("/v1/formations")
                        .header(header::COOKIE, cookie_pair.clone())
                        .body(Body::from("payload that must never reach upstream"))
                        .unwrap(),
                )
                .await
                .unwrap();
            assert_eq!(
                resp.status(),
                StatusCode::METHOD_NOT_ALLOWED,
                "inbound {m} should be 405"
            );
            assert_eq!(
                resp.headers().get(header::ALLOW).map(|v| v.as_bytes()),
                Some(b"GET" as &[u8]),
                "405 response for {m} must carry `Allow: GET`"
            );
        }

        let get_resp = app
            .clone()
            .oneshot(
                Request::builder()
                    .method("GET")
                    .uri("/v1/formations")
                    .header(header::COOKIE, cookie_pair.clone())
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_ne!(
            get_resp.status(),
            StatusCode::METHOD_NOT_ALLOWED,
            "GET must pass the 405 middleware"
        );

        // Give the mock's connection tasks a moment to record what they saw.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let observed = methods_seen.lock().unwrap().clone();
        let non_get_count = observed.iter().filter(|m| m.as_str() != "GET").count();
        assert_eq!(
            non_get_count, 0,
            "upstream saw non-GET method(s): {observed:?}"
        );
        assert!(
            observed.iter().any(|m| m == "GET"),
            "expected at least one GET to reach upstream, saw {observed:?}"
        );
    }
}