1use std::net::SocketAddr;
35use std::path::PathBuf;
36use std::sync::Arc;
37
38use axum::body::{Body, Bytes};
39use axum::extract::{ws::WebSocketUpgrade, Path as AxumPath, State};
40use axum::http::{header, HeaderMap, HeaderName, HeaderValue, Method, StatusCode, Uri};
41use axum::response::{IntoResponse, Response};
42use axum::routing::{any, post};
43use axum::{Json, Router};
44use base64::Engine as _;
45use clap::ValueEnum;
46use futures_util::{SinkExt, StreamExt};
47use rand::RngCore;
48use serde::Deserialize;
49use tokio::sync::RwLock;
50use tower_http::services::ServeDir;
51
52use crate::client::CellosClient;
53use crate::config::Config;
54use crate::exit::{CtlError, CtlResult};
55
/// Location of the built web bundle relative to this crate's manifest
/// directory; used when `CELLCTL_WEBUI_BUNDLE_DIR` is not set.
const BUNDLE_DIR_RELATIVE: &str = "static";
61
// Which listeners `cellctl webui` starts (mapped by `resolve_bind_plan`).
// Plain `//` comments are used deliberately: clap's `ValueEnum` derive may
// surface `///` docs as `--help` text, which would change CLI output.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, ValueEnum)]
pub enum BindMode {
    // Default. On Unix: loopback TCP plus a Unix-domain socket.
    // Elsewhere: loopback TCP only.
    #[default]
    Auto,
    // Loopback TCP only (127.0.0.1 with an OS-assigned port).
    Loopback,
    // Unix-domain socket only; no browser URL is printed. Rejected as a
    // usage error on non-Unix platforms.
    Unix,
}
76
/// State shared by every webui handler. Every field is behind an `Arc`, so
/// the per-request `Clone` only bumps reference counts.
#[derive(Clone)]
struct AppState {
    /// Upstream server base URL with any trailing `/` removed.
    upstream_base: Arc<String>,
    /// Bearer token attached to proxied upstream requests, if configured.
    upstream_bearer: Arc<Option<String>>,
    /// Launch token the browser must present to `POST /auth/exchange`.
    session_token: Arc<String>,
    /// Cookie value issued by the most recent successful exchange; `None`
    /// until the first exchange. Only this value passes the cookie check.
    session_cookie: Arc<RwLock<Option<String>>>,
    /// Directory the static web bundle is served from.
    bundle_dir: Arc<PathBuf>,
}
93
/// JSON body of `POST /auth/exchange`: `{"sess": "<launch token>"}`.
#[derive(Deserialize)]
struct ExchangeRequest {
    /// The launch session token (the value printed after `#sess=` in the
    /// launch URL).
    sess: String,
}
98
99pub async fn run(cfg: &Config, open: bool, bind: BindMode) -> CtlResult<()> {
101 let _ = CellosClient::new(cfg)?;
106
107 let upstream_base = cfg.effective_server().trim_end_matches('/').to_string();
108 let upstream_bearer = cfg.effective_token();
109
110 let bundle_dir = resolve_bundle_dir()?;
111 if !bundle_dir.join("index.html").exists() {
112 return Err(CtlError::usage(format!(
113 "webui bundle not found at {}/index.html — run `npm --prefix web run build` first",
114 bundle_dir.display()
115 )));
116 }
117
118 let session_token = mint_session_token();
119
120 let state = AppState {
121 upstream_base: Arc::new(upstream_base),
122 upstream_bearer: Arc::new(upstream_bearer),
123 session_token: Arc::new(session_token.clone()),
124 session_cookie: Arc::new(RwLock::new(None)),
125 bundle_dir: Arc::new(bundle_dir.clone()),
126 };
127
128 let app = build_router(state);
129
130 let (want_tcp, want_unix) = resolve_bind_plan(bind)?;
132
133 let tcp_listener = if want_tcp {
134 let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
135 Some(
136 tokio::net::TcpListener::bind(addr)
137 .await
138 .map_err(|e| CtlError::usage(format!("bind 127.0.0.1: {e}")))?,
139 )
140 } else {
141 None
142 };
143
144 #[cfg(unix)]
145 let unix_socket_path: Option<PathBuf> = if want_unix {
146 Some(unix_socket_path_for_pid(std::process::id()))
147 } else {
148 None
149 };
150 #[cfg(not(unix))]
151 let unix_socket_path: Option<PathBuf> = None;
152
153 let browser_url = if let Some(l) = tcp_listener.as_ref() {
154 let local_addr = l
155 .local_addr()
156 .map_err(|e| CtlError::usage(format!("local_addr: {e}")))?;
157 Some(format!("http://{}/#sess={}", local_addr, session_token))
158 } else {
159 None
160 };
161
162 if let Some(url) = browser_url.as_ref() {
163 println!("cellctl webui: {}", url);
164 }
165 if let Some(p) = unix_socket_path.as_ref() {
166 println!("cellctl webui: unix://{}", p.display());
167 }
168 if browser_url.is_none() {
169 eprintln!(
171 "cellctl webui: --bind unix has no browser-reachable URL; \
172 use a forwarder (e.g. `socat TCP-LISTEN:0,reuseaddr,fork UNIX-CONNECT:{}`) \
173 or rerun with `--bind auto` for a loopback URL.",
174 unix_socket_path
175 .as_ref()
176 .map(|p| p.display().to_string())
177 .unwrap_or_else(|| "<socket>".to_string())
178 );
179 }
180 println!("upstream: {}", state_upstream_for_log(&app));
181 println!("press Ctrl-C to stop");
182
183 if open {
184 if let Some(url) = browser_url.as_ref() {
185 if let Err(e) = opener::open(url) {
186 eprintln!("cellctl webui: could not launch browser: {e}");
187 }
188 } else {
189 eprintln!("cellctl webui: --open ignored: no loopback URL bound (use --bind auto)");
190 }
191 }
192
193 let (shutdown_tx, _shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
197 {
198 let shutdown_tx = shutdown_tx.clone();
199 tokio::spawn(async move {
200 let _ = tokio::signal::ctrl_c().await;
201 eprintln!("shutting down");
202 let _ = shutdown_tx.send(());
203 });
204 }
205
206 let tcp_task = if let Some(listener) = tcp_listener {
208 let app = app.clone();
209 let mut rx = shutdown_tx.subscribe();
210 Some(tokio::spawn(async move {
211 axum::serve(listener, app)
212 .with_graceful_shutdown(async move {
213 let _ = rx.recv().await;
214 })
215 .await
216 }))
217 } else {
218 None
219 };
220
221 #[cfg(unix)]
223 let unix_task = if let Some(path) = unix_socket_path.clone() {
224 let app = app.clone();
225 let mut rx = shutdown_tx.subscribe();
226 let listener = bind_unix_listener(&path)?;
227 Some(tokio::spawn(async move {
228 serve_unix(listener, app, async move {
229 let _ = rx.recv().await;
230 })
231 .await
232 }))
233 } else {
234 None
235 };
236 #[cfg(not(unix))]
237 let unix_task: Option<tokio::task::JoinHandle<std::io::Result<()>>> = None;
238
239 let mut first_err: Option<String> = None;
241 if let Some(t) = tcp_task {
242 match t.await {
243 Ok(Ok(())) => {}
244 Ok(Err(e)) => {
245 let _ = first_err.get_or_insert_with(|| format!("tcp: {e}"));
246 }
247 Err(e) => {
248 let _ = first_err.get_or_insert_with(|| format!("tcp join: {e}"));
249 }
250 }
251 }
252 if let Some(t) = unix_task {
253 match t.await {
254 Ok(Ok(())) => {}
255 Ok(Err(e)) => {
256 let _ = first_err.get_or_insert_with(|| format!("unix: {e}"));
257 }
258 Err(e) => {
259 let _ = first_err.get_or_insert_with(|| format!("unix join: {e}"));
260 }
261 }
262 }
263
264 if let Some(p) = unix_socket_path.as_ref() {
267 let _ = std::fs::remove_file(p);
268 }
269
270 if let Some(e) = first_err {
271 return Err(CtlError::api(format!("webui server: {e}")));
272 }
273 Ok(())
274}
275
276fn resolve_bind_plan(bind: BindMode) -> CtlResult<(bool, bool)> {
281 #[cfg(unix)]
282 {
283 Ok(match bind {
284 BindMode::Auto => (true, true),
285 BindMode::Loopback => (true, false),
286 BindMode::Unix => (false, true),
287 })
288 }
289 #[cfg(not(unix))]
290 {
291 Ok(match bind {
292 BindMode::Auto => (true, false),
293 BindMode::Loopback => (true, false),
294 BindMode::Unix => {
295 return Err(CtlError::usage(
296 "--bind unix is not supported on Windows; use --bind loopback (the default)",
297 ));
298 }
299 })
300 }
301}
302
/// Path of the Unix-domain socket for the webui process with id `pid`.
///
/// The socket goes in `$XDG_RUNTIME_DIR` when that variable holds an
/// absolute path, and in `/tmp` otherwise. The XDG Base Directory
/// Specification says relative values are invalid and must be ignored, and
/// an empty value is treated as unset. Without this check, such a value
/// would place the socket relative to the current working directory.
#[cfg(unix)]
fn unix_socket_path_for_pid(pid: u32) -> PathBuf {
    let dir = std::env::var_os("XDG_RUNTIME_DIR")
        .map(PathBuf::from)
        .filter(|p| p.is_absolute())
        .unwrap_or_else(|| PathBuf::from("/tmp"));
    dir.join(format!("cellctl-webui-{pid}.sock"))
}
312
313#[cfg(unix)]
317fn bind_unix_listener(path: &std::path::Path) -> CtlResult<tokio::net::UnixListener> {
318 use std::os::unix::fs::PermissionsExt;
319
320 let _ = std::fs::remove_file(path);
322
323 let listener = tokio::net::UnixListener::bind(path)
324 .map_err(|e| CtlError::usage(format!("bind {}: {e}", path.display())))?;
325
326 let perms = std::fs::Permissions::from_mode(0o600);
328 std::fs::set_permissions(path, perms)
329 .map_err(|e| CtlError::usage(format!("chmod 0600 {}: {e}", path.display())))?;
330
331 Ok(listener)
332}
333
334#[cfg(unix)]
340async fn serve_unix(
341 listener: tokio::net::UnixListener,
342 app: Router,
343 shutdown: impl std::future::Future<Output = ()>,
344) -> std::io::Result<()> {
345 use std::convert::Infallible;
346 use tower::Service;
347
348 tokio::pin!(shutdown);
353
354 loop {
355 tokio::select! {
356 _ = &mut shutdown => return Ok(()),
357 accepted = listener.accept() => {
358 let (stream, _peer) = match accepted {
359 Ok(s) => s,
360 Err(e) => {
361 eprintln!("webui: unix accept error: {e}");
363 continue;
364 }
365 };
366 let app = app.clone();
367 tokio::spawn(async move {
368 let svc = hyper::service::service_fn(move |req: http::Request<hyper::body::Incoming>| {
372 let mut router = app.clone();
373 async move {
374 let resp: Response = match router.call(req.map(Body::new)).await {
375 Ok(r) => r,
376 Err(never) => match never {},
377 };
378 Ok::<_, Infallible>(resp)
379 }
380 });
381 let io = hyper_util::rt::TokioIo::new(stream);
382 let _ = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
383 .serve_connection_with_upgrades(io, svc)
384 .await;
385 });
386 }
387 }
388 }
389}
390
391fn mint_session_token() -> String {
393 let mut buf = [0u8; 32];
394 rand::thread_rng().fill_bytes(&mut buf);
395 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(buf)
396}
397
398fn resolve_bundle_dir() -> CtlResult<PathBuf> {
402 if let Ok(p) = std::env::var("CELLCTL_WEBUI_BUNDLE_DIR") {
403 return Ok(PathBuf::from(p));
404 }
405 let from_manifest = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(BUNDLE_DIR_RELATIVE);
408 Ok(from_manifest)
409}
410
411fn build_router(state: AppState) -> Router {
414 let bundle_dir = state.bundle_dir.as_ref().clone();
417 let serve_dir = ServeDir::new(&bundle_dir).append_index_html_on_directories(true);
418
419 Router::new()
420 .route("/auth/exchange", post(auth_exchange))
421 .route("/v1/*rest", any(proxy_v1))
422 .route("/ws/events", any(ws_events))
423 .fallback_service(serve_dir)
427 .layer(axum::middleware::from_fn(reject_non_get))
428 .with_state(state)
429}
430
431async fn reject_non_get(req: axum::http::Request<Body>, next: axum::middleware::Next) -> Response {
435 let method = req.method().clone();
436 let path = req.uri().path().to_string();
437
438 let is_auth_exchange = method == Method::POST && path == "/auth/exchange";
442 if method != Method::GET && !is_auth_exchange {
443 return method_not_allowed();
444 }
445 next.run(req).await
446}
447
448fn method_not_allowed() -> Response {
449 let mut resp = (StatusCode::METHOD_NOT_ALLOWED, "method not allowed\n").into_response();
450 resp.headers_mut()
451 .insert(header::ALLOW, HeaderValue::from_static("GET"));
452 resp
453}
454
455async fn auth_exchange(
459 State(state): State<AppState>,
460 Json(body): Json<ExchangeRequest>,
461) -> Response {
462 let provided = body.sess.as_bytes();
465 let expected = state.session_token.as_bytes();
466 if provided.len() != expected.len()
467 || !provided
468 .iter()
469 .zip(expected.iter())
470 .fold(true, |acc, (a, b)| acc & (a == b))
471 {
472 return (StatusCode::UNAUTHORIZED, "bad session token\n").into_response();
473 }
474
475 let mut buf = [0u8; 32];
478 rand::thread_rng().fill_bytes(&mut buf);
479 let cookie_value = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(buf);
480
481 {
482 let mut slot = state.session_cookie.write().await;
483 *slot = Some(cookie_value.clone());
484 }
485
486 let cookie_header = format!(
487 "cellctl_session={}; HttpOnly; SameSite=Strict; Path=/",
488 cookie_value
489 );
490
491 let mut resp = (StatusCode::OK, "ok\n").into_response();
492 resp.headers_mut().insert(
493 header::SET_COOKIE,
494 HeaderValue::from_str(&cookie_header).unwrap(),
495 );
496 resp
497}
498
499async fn require_session_cookie(state: &AppState, headers: &HeaderMap) -> bool {
502 let expected = match state.session_cookie.read().await.clone() {
503 Some(v) => v,
504 None => return false,
505 };
506 let Some(cookie_hdr) = headers.get(header::COOKIE) else {
507 return false;
508 };
509 let Ok(cookie_str) = cookie_hdr.to_str() else {
510 return false;
511 };
512 for entry in cookie_str.split(';') {
514 let entry = entry.trim();
515 if let Some(v) = entry.strip_prefix("cellctl_session=") {
516 return v == expected;
517 }
518 }
519 false
520}
521
522async fn proxy_v1(
524 State(state): State<AppState>,
525 AxumPath(rest): AxumPath<String>,
526 uri: Uri,
527 headers: HeaderMap,
528) -> Response {
529 if !require_session_cookie(&state, &headers).await {
530 return unauthorized();
531 }
532
533 let query = uri.query().map(|q| format!("?{q}")).unwrap_or_default();
534 let upstream_url = format!("{}/v1/{}{}", state.upstream_base, rest, query);
535
536 let client = match reqwest::Client::builder().build() {
537 Ok(c) => c,
538 Err(e) => return upstream_error(format!("client: {e}")),
539 };
540
541 let mut req = client.get(&upstream_url);
542 if let Some(tok) = state.upstream_bearer.as_ref() {
543 req = req.header(reqwest::header::AUTHORIZATION, format!("Bearer {tok}"));
544 }
545
546 let upstream_resp = match req.send().await {
547 Ok(r) => r,
548 Err(e) => return upstream_error(format!("send: {e}")),
549 };
550
551 let status = upstream_resp.status();
552 let resp_headers = upstream_resp.headers().clone();
553 let body_bytes = match upstream_resp.bytes().await {
554 Ok(b) => b,
555 Err(e) => return upstream_error(format!("read body: {e}")),
556 };
557
558 let mut out = Response::new(Body::from(body_bytes));
559 *out.status_mut() = StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
560 if let Some(ct) = resp_headers.get(reqwest::header::CONTENT_TYPE) {
562 if let Ok(v) = HeaderValue::from_bytes(ct.as_bytes()) {
563 out.headers_mut().insert(header::CONTENT_TYPE, v);
564 }
565 }
566 out
567}
568
/// `/ws/events`: relay a WebSocket between the browser and the upstream's
/// `/ws/events` endpoint.
///
/// Requires the session cookie (401 otherwise) and a WebSocket upgrade
/// request (400 otherwise). The upstream URL is derived from
/// `upstream_base`: `https://` becomes `wss://`, `http://` becomes `ws://`,
/// and any other base is used verbatim. The client's query string is
/// appended, and the configured bearer token, if any, is sent as
/// `Authorization`.
///
/// After the upgrade, failing to build the upstream request or to connect
/// ends the task without sending anything to the browser. Text, Binary, Ping
/// and Pong messages are copied in both directions, and send errors are
/// ignored. A Close, end of stream, or receive error on either side ends the
/// relay. Raw `Frame` messages from upstream are discarded.
async fn ws_events(
    State(state): State<AppState>,
    uri: Uri,
    headers: HeaderMap,
    ws: Option<WebSocketUpgrade>,
) -> Response {
    // The cookie check runs first, so an unauthenticated non-upgrade request
    // gets 401 rather than 400.
    if !require_session_cookie(&state, &headers).await {
        return unauthorized();
    }
    let Some(ws) = ws else {
        return (StatusCode::BAD_REQUEST, "expected websocket upgrade\n").into_response();
    };

    let query = uri.query().map(|q| format!("?{q}")).unwrap_or_default();
    let ws_url = {
        let base = state.upstream_base.as_str();
        // Map the HTTP scheme onto its WebSocket counterpart.
        let ws_base = if let Some(rest) = base.strip_prefix("https://") {
            format!("wss://{rest}")
        } else if let Some(rest) = base.strip_prefix("http://") {
            format!("ws://{rest}")
        } else {
            base.to_string()
        };
        format!("{ws_base}/ws/events{query}")
    };
    let bearer = state.upstream_bearer.as_ref().clone();

    ws.on_upgrade(move |client_ws| async move {
        let (mut client_tx, mut client_rx) = client_ws.split();

        // Build the upstream handshake request so an Authorization header
        // can be added to it.
        let mut request =
            match tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(
                ws_url.as_str(),
            ) {
                Ok(r) => r,
                Err(_) => return,
            };
        if let Some(tok) = bearer {
            if let Ok(v) = tokio_tungstenite::tungstenite::http::HeaderValue::from_str(&format!(
                "Bearer {tok}"
            )) {
                request.headers_mut().insert(
                    tokio_tungstenite::tungstenite::http::header::AUTHORIZATION,
                    v,
                );
            }
        }
        let (upstream_ws, _) = match tokio_tungstenite::connect_async(request).await {
            Ok(p) => p,
            Err(_) => return,
        };
        let (mut up_tx, mut up_rx) = upstream_ws.split();

        // Pump messages in both directions until either side closes or errors.
        loop {
            tokio::select! {
                // Browser -> upstream.
                msg = client_rx.next() => match msg {
                    Some(Ok(axum::extract::ws::Message::Text(t))) => {
                        let _ = up_tx
                            .send(tokio_tungstenite::tungstenite::Message::Text(t))
                            .await;
                    }
                    Some(Ok(axum::extract::ws::Message::Binary(b))) => {
                        let _ = up_tx
                            .send(tokio_tungstenite::tungstenite::Message::Binary(b))
                            .await;
                    }
                    Some(Ok(axum::extract::ws::Message::Ping(p))) => {
                        let _ = up_tx
                            .send(tokio_tungstenite::tungstenite::Message::Ping(p))
                            .await;
                    }
                    Some(Ok(axum::extract::ws::Message::Pong(p))) => {
                        let _ = up_tx
                            .send(tokio_tungstenite::tungstenite::Message::Pong(p))
                            .await;
                    }
                    Some(Ok(axum::extract::ws::Message::Close(_))) | None => break,
                    Some(Err(_)) => break,
                },
                // Upstream -> browser.
                msg = up_rx.next() => match msg {
                    Some(Ok(tokio_tungstenite::tungstenite::Message::Text(t))) => {
                        let _ = client_tx
                            .send(axum::extract::ws::Message::Text(t))
                            .await;
                    }
                    Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(b))) => {
                        let _ = client_tx
                            .send(axum::extract::ws::Message::Binary(b))
                            .await;
                    }
                    Some(Ok(tokio_tungstenite::tungstenite::Message::Ping(p))) => {
                        let _ = client_tx
                            .send(axum::extract::ws::Message::Ping(p))
                            .await;
                    }
                    Some(Ok(tokio_tungstenite::tungstenite::Message::Pong(p))) => {
                        let _ = client_tx
                            .send(axum::extract::ws::Message::Pong(p))
                            .await;
                    }
                    Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) | None => break,
                    // Raw frames have no browser-side equivalent; drop them.
                    Some(Ok(tokio_tungstenite::tungstenite::Message::Frame(_))) => {}
                    Some(Err(_)) => break,
                },
            }
        }
    })
}
681
682fn unauthorized() -> Response {
683 (StatusCode::UNAUTHORIZED, "missing session cookie\n").into_response()
684}
685
686fn upstream_error(msg: String) -> Response {
687 (StatusCode::BAD_GATEWAY, format!("upstream: {msg}\n")).into_response()
688}
689
690fn state_upstream_for_log(_app: &Router) -> &'static str {
694 "(see state)"
699}
700
701#[allow(dead_code)]
702const _BYTES_KEEP: fn() = || {
703 let _: Bytes = Bytes::new();
706 let _: HeaderName = HeaderName::from_static("x-keep");
707};
708
#[cfg(test)]
mod tests {
    // Router-level tests: requests go through `build_router` in-process via
    // `tower::ServiceExt::oneshot`. The only real socket is the mock upstream
    // in `proxy_only_emits_get_to_upstream`.
    use super::*;
    use axum::body::to_bytes;
    use axum::http::Request;
    use tower::ServiceExt;

    /// State with session token `test-token`, no upstream bearer, no issued
    /// cookie, and upstream `http://127.0.0.1:0` (presumably unreachable, so
    /// proxied requests fail to connect).
    fn test_state(bundle_dir: PathBuf) -> AppState {
        AppState {
            upstream_base: Arc::new("http://127.0.0.1:0".to_string()),
            upstream_bearer: Arc::new(None),
            session_token: Arc::new("test-token".to_string()),
            session_cookie: Arc::new(RwLock::new(None)),
            bundle_dir: Arc::new(bundle_dir),
        }
    }

    /// Use the crate's built `static/` bundle when present. Otherwise write a
    /// minimal `index.html` into a per-process temp dir.
    /// NOTE(review): the temp dir is never removed.
    fn ensure_bundle_dir() -> PathBuf {
        let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let real = manifest_dir.join("static");
        if real.join("index.html").exists() {
            return real;
        }
        let tmp = std::env::temp_dir().join(format!("cellctl-webui-test-{}", std::process::id()));
        std::fs::create_dir_all(&tmp).expect("mkdir tmp bundle");
        std::fs::write(
            tmp.join("index.html"),
            "<!doctype html><title>cellctl webui</title>",
        )
        .expect("write index.html");
        tmp
    }

    /// `GET /` falls through to the static bundle and returns HTML.
    #[tokio::test]
    async fn serves_index_at_root() {
        let bundle = ensure_bundle_dir();
        let app = build_router(test_state(bundle));

        let resp = app
            .oneshot(
                Request::builder()
                    .method("GET")
                    .uri("/")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = to_bytes(resp.into_body(), 64 * 1024).await.unwrap();
        let body_str = std::str::from_utf8(&body).unwrap();
        assert!(
            body_str.to_ascii_lowercase().contains("<!doctype html")
                || body_str.to_ascii_lowercase().contains("<html"),
            "expected HTML at /, got: {body_str:.200}"
        );
    }

    /// The middleware answers DELETE with 405 and `Allow: GET`, even
    /// without a session cookie.
    #[tokio::test]
    async fn non_get_returns_405() {
        let bundle = ensure_bundle_dir();
        let app = build_router(test_state(bundle));

        let resp = app
            .oneshot(
                Request::builder()
                    .method("DELETE")
                    .uri("/v1/formations")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(
            resp.headers().get(header::ALLOW).map(|v| v.as_bytes()),
            Some(b"GET" as &[u8]),
        );
    }

    /// PUT with a body to a proxied path is rejected with 405.
    #[tokio::test]
    async fn put_to_v1_returns_405() {
        let bundle = ensure_bundle_dir();
        let app = build_router(test_state(bundle));

        let resp = app
            .oneshot(
                Request::builder()
                    .method("PUT")
                    .uri("/v1/formations/foo")
                    .body(Body::from("body"))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
    }

    /// A proxied GET without the session cookie gets 401.
    #[tokio::test]
    async fn proxy_without_cookie_returns_401() {
        let bundle = ensure_bundle_dir();
        let app = build_router(test_state(bundle));

        let resp = app
            .oneshot(
                Request::builder()
                    .method("GET")
                    .uri("/v1/formations")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    /// Exchanging a wrong launch token yields 401.
    #[tokio::test]
    async fn auth_exchange_with_wrong_token_returns_401() {
        let bundle = ensure_bundle_dir();
        let app = build_router(test_state(bundle));

        let resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/auth/exchange")
                    .header(header::CONTENT_TYPE, "application/json")
                    .body(Body::from(r#"{"sess":"wrong"}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    /// The right token yields an HttpOnly, SameSite=Strict cookie whose
    /// value matches what was stored in the shared state.
    #[tokio::test]
    async fn auth_exchange_with_right_token_sets_cookie() {
        let bundle = ensure_bundle_dir();
        let state = test_state(bundle);
        let app = build_router(state.clone());

        let resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/auth/exchange")
                    .header(header::CONTENT_TYPE, "application/json")
                    .body(Body::from(r#"{"sess":"test-token"}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let cookie = resp
            .headers()
            .get(header::SET_COOKIE)
            .expect("Set-Cookie header present")
            .to_str()
            .unwrap()
            .to_string();
        assert!(cookie.starts_with("cellctl_session="));
        assert!(cookie.contains("HttpOnly"));
        assert!(cookie.contains("SameSite=Strict"));

        // `state` shares its Arcs with the router, so the stored value is visible here.
        let stored = state.session_cookie.read().await.clone();
        assert!(stored.is_some());
        let stored = stored.unwrap();
        assert!(cookie.contains(&format!("cellctl_session={stored}")));
    }

    /// With a valid cookie the request passes the gate. The unreachable
    /// upstream then fails, so only "not 401" is asserted.
    #[tokio::test]
    async fn proxy_with_valid_cookie_attempts_upstream() {
        let bundle = ensure_bundle_dir();
        let state = test_state(bundle);
        let app = build_router(state.clone());

        let exch = app
            .clone()
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/auth/exchange")
                    .header(header::CONTENT_TYPE, "application/json")
                    .body(Body::from(r#"{"sess":"test-token"}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(exch.status(), StatusCode::OK);
        let cookie_hdr = exch
            .headers()
            .get(header::SET_COOKIE)
            .unwrap()
            .to_str()
            .unwrap()
            .to_string();
        // Keep only `name=value`, dropping the Set-Cookie attributes.
        let cookie_pair = cookie_hdr.split(';').next().unwrap().trim().to_string();

        let resp = app
            .oneshot(
                Request::builder()
                    .method("GET")
                    .uri("/v1/formations")
                    .header(header::COOKIE, cookie_pair)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_ne!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    /// End-to-end check with a mock upstream that records the request method
    /// of each connection. Authenticated POST/PUT/DELETE/PATCH must be
    /// rejected with 405 and never reach upstream, while a GET must.
    #[tokio::test]
    async fn proxy_only_emits_get_to_upstream() {
        use std::sync::Mutex;
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        let mock = TcpListener::bind("127.0.0.1:0").await.expect("bind mock");
        let mock_addr = mock.local_addr().expect("mock addr");
        let methods_seen: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let methods_for_task = methods_seen.clone();

        // Mock upstream: one read per connection, record the first token of
        // the request line, answer `200 {}` and close.
        tokio::spawn(async move {
            loop {
                let (mut stream, _) = match mock.accept().await {
                    Ok(p) => p,
                    Err(_) => return,
                };
                let methods = methods_for_task.clone();
                tokio::spawn(async move {
                    let mut buf = [0u8; 4096];
                    let n = match stream.read(&mut buf).await {
                        Ok(n) => n,
                        Err(_) => return,
                    };
                    let head = String::from_utf8_lossy(&buf[..n]).to_string();
                    if let Some(first_line) = head.lines().next() {
                        if let Some(method) = first_line.split_whitespace().next() {
                            methods.lock().unwrap().push(method.to_string());
                        }
                    }
                    let _ = stream
                        .write_all(
                            b"HTTP/1.1 200 OK\r\n\
                              Content-Type: application/json\r\n\
                              Content-Length: 2\r\n\
                              Connection: close\r\n\
                              \r\n\
                              {}",
                        )
                        .await;
                    let _ = stream.shutdown().await;
                });
            }
        });

        let bundle = ensure_bundle_dir();
        let state = AppState {
            upstream_base: Arc::new(format!("http://{}", mock_addr)),
            upstream_bearer: Arc::new(Some("test-bearer".to_string())),
            session_token: Arc::new("test-token".to_string()),
            session_cookie: Arc::new(RwLock::new(None)),
            bundle_dir: Arc::new(bundle),
        };
        let app = build_router(state.clone());

        // Authenticate first so a 405 cannot be a side effect of a missing cookie.
        let exch = app
            .clone()
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/auth/exchange")
                    .header(header::CONTENT_TYPE, "application/json")
                    .body(Body::from(r#"{"sess":"test-token"}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(exch.status(), StatusCode::OK);
        let cookie_pair = exch
            .headers()
            .get(header::SET_COOKIE)
            .unwrap()
            .to_str()
            .unwrap()
            .split(';')
            .next()
            .unwrap()
            .trim()
            .to_string();

        let non_get_methods = ["POST", "PUT", "DELETE", "PATCH"];
        for m in non_get_methods {
            let resp = app
                .clone()
                .oneshot(
                    Request::builder()
                        .method(m)
                        .uri("/v1/formations")
                        .header(header::COOKIE, cookie_pair.clone())
                        .body(Body::from("payload that must never reach upstream"))
                        .unwrap(),
                )
                .await
                .unwrap();
            assert_eq!(
                resp.status(),
                StatusCode::METHOD_NOT_ALLOWED,
                "inbound {m} should be 405"
            );
            assert_eq!(
                resp.headers().get(header::ALLOW).map(|v| v.as_bytes()),
                Some(b"GET" as &[u8]),
                "405 response for {m} must carry `Allow: GET`"
            );
        }

        // Positive control: an authenticated GET passes the middleware.
        let get_resp = app
            .clone()
            .oneshot(
                Request::builder()
                    .method("GET")
                    .uri("/v1/formations")
                    .header(header::COOKIE, cookie_pair.clone())
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_ne!(
            get_resp.status(),
            StatusCode::METHOD_NOT_ALLOWED,
            "GET must pass the 405 middleware"
        );

        // Give the mock's connection tasks a moment to record what they saw.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let observed = methods_seen.lock().unwrap().clone();
        let non_get_count = observed.iter().filter(|m| m.as_str() != "GET").count();
        assert_eq!(
            non_get_count, 0,
            "upstream saw non-GET method(s): {observed:?}"
        );
        assert!(
            observed.iter().any(|m| m == "GET"),
            "expected at least one GET to reach upstream, saw {observed:?}"
        );
    }
}