use super::serve_web_client;
use super::*;
use futures_util::{SinkExt, StreamExt};
use isahc::prelude::*;
use serde_json;
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use tokio::time::timeout;
use tokio_tungstenite::tungstenite::http::Request;
use tokio_tungstenite::{connect_async, tungstenite::Message};
use zellij_utils::input::layout::Layout;
use zellij_utils::{consts::VERSION, input::config::Config, input::options::Options};
use crate::os_input_output::ClientOsApi;
use crate::web_client::control_message::{
WebClientToWebServerControlMessage, WebClientToWebServerControlMessagePayload,
WebServerToWebClientControlMessage,
};
use crate::web_client::ClientOsApiFactory;
use zellij_utils::{
data::{LayoutInfo, Palette},
errors::ErrorContext,
ipc::{ClientAttributes, ClientToServerMsg, ServerToClientMsg},
pane_size::Size,
web_authentication_tokens::{create_token, delete_db, revoke_token},
};
use serial_test::serial;
mod web_client_tests {
use super::*;
use std::time::{Duration, Instant};
async fn wait_for_server(port: u16, timeout: Duration) -> Result<(), String> {
let start = Instant::now();
let url = format!("http://127.0.0.1:{}/info/version", port);
while start.elapsed() < timeout {
match tokio::task::spawn_blocking({
let url = url.clone();
move || isahc::get(&url)
})
.await
{
Ok(Ok(_)) => {
return Ok(());
},
Ok(Err(e)) => {
eprintln!("HTTP request failed: {:?}", e);
},
Err(e) => {
eprintln!("Task spawn failed: {:?}", e);
},
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
Err(format!(
"HTTP server failed to start on port {} within {:?}",
port, timeout
))
}
#[tokio::test]
#[serial]
async fn test_version_endpoint() {
let _ = delete_db();
let session_manager = Arc::new(MockSessionManager::new());
let client_os_api_factory = Arc::new(MockClientOsApiFactory::new());
let config = Config::default();
let options = Options::default();
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let port = addr.port();
let temp_config_path = std::env::temp_dir().join("test_config.kdl");
let server_handle = tokio::spawn(serve_web_client(
config,
options,
Some(temp_config_path),
listener,
None,
Some(session_manager),
Some(client_os_api_factory),
));
wait_for_server(port, Duration::from_secs(5))
.await
.expect("Server failed to start");
let url = format!("http://127.0.0.1:{}/info/version", port);
let mut response = timeout(
Duration::from_secs(5),
tokio::task::spawn_blocking(move || isahc::get(&url)),
)
.await
.expect("Request timed out")
.expect("Spawn blocking failed")
.expect("Request failed");
assert!(response.status().is_success());
let version_text = response.text().expect("Failed to read response body");
assert_eq!(version_text, VERSION);
server_handle.abort();
tokio::time::sleep(Duration::from_millis(100)).await;
}
#[tokio::test]
#[serial]
async fn test_login_endpoint() {
let _ = delete_db();
let test_token_name = "test_token_login";
let (auth_token, _) =
create_token(Some(test_token_name.to_string())).expect("Failed to create test token");
let session_manager = Arc::new(MockSessionManager::new());
let client_os_api_factory = Arc::new(MockClientOsApiFactory::new());
let config = Config::default();
let options = Options::default();
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let port = addr.port();
let temp_config_path = std::env::temp_dir().join("test_config.kdl");
let server_handle = tokio::spawn(async move {
serve_web_client(
config,
options,
Some(temp_config_path),
listener,
None,
Some(session_manager),
Some(client_os_api_factory),
)
.await;
});
wait_for_server(port, Duration::from_secs(5))
.await
.expect("Server failed to start");
let login_url = format!("http://127.0.0.1:{}/command/login", port);
let login_payload = serde_json::json!({
"auth_token": auth_token,
"remember_me": true
});
let mut response = timeout(
Duration::from_secs(5),
tokio::task::spawn_blocking(move || {
isahc::Request::post(&login_url)
.header("Content-Type", "application/json")
.body(login_payload.to_string())
.unwrap()
.send()
}),
)
.await
.expect("Login request timed out")
.expect("Spawn blocking failed")
.expect("Login request failed");
assert!(response.status().is_success());
let response_text = response.text().expect("Failed to read response body");
let response_json: serde_json::Value =
serde_json::from_str(&response_text).expect("Failed to parse JSON");
assert_eq!(response_json["success"], true);
assert_eq!(response_json["message"], "Login successful");
println!("✓ Login endpoint test passed");
server_handle.abort();
revoke_token(test_token_name).expect("Failed to revoke test token");
tokio::time::sleep(Duration::from_millis(100)).await;
}
#[tokio::test]
#[serial]
async fn test_invalid_auth_token_login() {
let _ = delete_db();
let session_manager = Arc::new(MockSessionManager::new());
let client_os_api_factory = Arc::new(MockClientOsApiFactory::new());
let config = Config::default();
let options = Options::default();
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let port = addr.port();
let temp_config_path = std::env::temp_dir().join("test_config.kdl");
let server_handle = tokio::spawn(async move {
serve_web_client(
config,
options,
Some(temp_config_path),
listener,
None,
Some(session_manager),
Some(client_os_api_factory),
)
.await;
});
wait_for_server(port, Duration::from_secs(5))
.await
.expect("Server failed to start");
let login_url = format!("http://127.0.0.1:{}/command/login", port);
let login_payload = serde_json::json!({
"auth_token": "invalid_token_123",
"remember_me": false
});
let response = timeout(
Duration::from_secs(5),
tokio::task::spawn_blocking(move || {
isahc::Request::post(&login_url)
.header("Content-Type", "application/json")
.body(login_payload.to_string())
.unwrap()
.send()
}),
)
.await
.expect("Login request timed out")
.expect("Spawn blocking failed")
.expect("Login request failed");
assert_eq!(response.status(), 401);
println!("✓ Invalid auth token correctly rejected");
server_handle.abort();
}
#[tokio::test]
#[serial]
async fn test_full_session_flow() {
let _ = delete_db();
let test_token_name = "test_token_session_flow";
let (auth_token, _) =
create_token(Some(test_token_name.to_string())).expect("Failed to create test token");
let session_manager = Arc::new(MockSessionManager::new());
let client_os_api_factory = Arc::new(MockClientOsApiFactory::new());
let factory_for_verification = client_os_api_factory.clone();
let config = Config::default();
let options = Options::default();
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let port = addr.port();
let temp_config_path = std::env::temp_dir().join("test_config.kdl");
let server_handle = tokio::spawn(async move {
serve_web_client(
config,
options,
Some(temp_config_path),
listener,
None,
Some(session_manager),
Some(client_os_api_factory),
)
.await;
});
wait_for_server(port, Duration::from_secs(5))
.await
.expect("Server failed to start");
let login_url = format!("http://127.0.0.1:{}/command/login", port);
let login_payload = serde_json::json!({
"auth_token": auth_token,
"remember_me": true
});
let login_response = timeout(
Duration::from_secs(5),
tokio::task::spawn_blocking(move || {
isahc::Request::post(&login_url)
.header("Content-Type", "application/json")
.body(login_payload.to_string())
.unwrap()
.send()
}),
)
.await
.unwrap()
.unwrap()
.unwrap();
assert!(login_response.status().is_success());
let set_cookie_header = login_response.headers().get("set-cookie");
assert!(
set_cookie_header.is_some(),
"Should have received session cookie"
);
let cookie_value = set_cookie_header.unwrap().to_str().unwrap();
let session_token = cookie_value
.split(';')
.next()
.and_then(|part| part.split('=').nth(1))
.unwrap();
println!("✓ Successfully logged in and received session token");
let session_url = format!("http://127.0.0.1:{}/session", port);
let mut client_response = timeout(
Duration::from_secs(5),
tokio::task::spawn_blocking({
let session_token = session_token.to_string();
move || {
isahc::Request::post(&session_url)
.header("Cookie", format!("session_token={}", session_token))
.header("Content-Type", "application/json")
.body("{}")
.unwrap()
.send()
}
}),
)
.await
.unwrap()
.unwrap()
.unwrap();
assert!(client_response.status().is_success());
let client_data: serde_json::Value =
serde_json::from_str(&client_response.text().unwrap()).unwrap();
let web_client_id = client_data["web_client_id"].as_str().unwrap().to_string();
println!("✓ Successfully created client session");
let control_ws_url = format!("ws://127.0.0.1:{}/ws/control", port);
let (control_ws, _) = timeout(
Duration::from_secs(5),
connect_async_with_cookie(&control_ws_url, session_token),
)
.await
.expect("Control WebSocket connection timed out")
.expect("Failed to connect to control WebSocket");
let (mut control_sink, mut control_stream) = control_ws.split();
let control_message = timeout(Duration::from_secs(2), control_stream.next())
.await
.expect("Timeout waiting for control message")
.expect("Control stream ended")
.expect("Error receiving control message");
if let Message::Text(text) = control_message {
let parsed: WebServerToWebClientControlMessage =
serde_json::from_str(&text).expect("Failed to parse control message");
match parsed {
WebServerToWebClientControlMessage::SetConfig(_) => {
println!("✓ Received expected SetConfig message");
},
_ => panic!("Expected SetConfig message, got: {:?}", parsed),
}
} else {
panic!("Expected text message, got: {:?}", control_message);
}
let resize_msg = WebClientToWebServerControlMessage {
web_client_id: web_client_id.clone(),
payload: WebClientToWebServerControlMessagePayload::TerminalResize(Size {
rows: 30,
cols: 100,
}),
};
control_sink
.send(Message::Text(serde_json::to_string(&resize_msg).unwrap()))
.await
.expect("Failed to send resize message");
println!("✓ Sent terminal resize message");
let terminal_ws_url = format!(
"ws://127.0.0.1:{}/ws/terminal?web_client_id={}",
port, web_client_id
);
let (terminal_ws, _) = timeout(
Duration::from_secs(5),
connect_async_with_cookie(&terminal_ws_url, session_token),
)
.await
.expect("Terminal WebSocket connection timed out")
.expect("Failed to connect to terminal WebSocket");
let (mut terminal_sink, _terminal_stream) = terminal_ws.split();
terminal_sink
.send(Message::Text("echo hello\n".to_string()))
.await
.expect("Failed to send terminal input");
println!("✓ Sent terminal input");
tokio::time::sleep(Duration::from_millis(500)).await;
let mock_apis = factory_for_verification.mock_apis.lock().unwrap();
let mut found_resize = false;
let mut found_terminal_input = false;
for (_, mock_api) in mock_apis.iter() {
let messages = mock_api.get_sent_messages();
for msg in messages {
match msg {
ClientToServerMsg::TerminalResize(_) => {
found_resize = true;
},
ClientToServerMsg::Key(_, _, _) | ClientToServerMsg::Action(_, _, _) => {
found_terminal_input = true;
},
_ => {},
}
}
}
assert!(
found_resize,
"Terminal resize message was not received by mock OS API"
);
println!("✓ Verified terminal resize message was processed by mock OS API");
assert!(
found_terminal_input,
"Terminal input message was not received by mock OS API"
);
println!("✓ Verified terminal input message was processed by mock OS API");
let _ = control_sink.close().await;
let _ = terminal_sink.close().await;
server_handle.abort();
revoke_token(test_token_name).expect("Failed to revoke test token");
println!("✓ Full session flow test completed successfully");
tokio::time::sleep(Duration::from_millis(100)).await;
}
#[tokio::test]
#[serial]
async fn test_unauthorized_access_without_session() {
let _ = delete_db();
let session_manager = Arc::new(MockSessionManager::new());
let client_os_api_factory = Arc::new(MockClientOsApiFactory::new());
let config = Config::default();
let options = Options::default();
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let port = addr.port();
let temp_config_path = std::env::temp_dir().join("test_config.kdl");
let server_handle = tokio::spawn(async move {
serve_web_client(
config,
options,
Some(temp_config_path),
listener,
None,
Some(session_manager),
Some(client_os_api_factory),
)
.await;
});
wait_for_server(port, Duration::from_secs(5))
.await
.expect("Server failed to start");
let session_url = format!("http://127.0.0.1:{}/session", port);
let response = timeout(
Duration::from_secs(5),
tokio::task::spawn_blocking(move || isahc::post(&session_url, "{}")),
)
.await
.expect("Session request timed out")
.expect("Spawn blocking failed")
.expect("Session request failed");
assert_eq!(response.status(), 401);
println!("✓ Unauthorized access correctly rejected");
server_handle.abort();
}
#[tokio::test]
#[serial]
async fn test_invalid_session_token() {
let _ = delete_db();
let session_manager = Arc::new(MockSessionManager::new());
let client_os_api_factory = Arc::new(MockClientOsApiFactory::new());
let config = Config::default();
let options = Options::default();
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let port = addr.port();
let temp_config_path = std::env::temp_dir().join("test_config.kdl");
let server_handle = tokio::spawn(async move {
serve_web_client(
config,
options,
Some(temp_config_path),
listener,
None,
Some(session_manager),
Some(client_os_api_factory),
)
.await;
});
wait_for_server(port, Duration::from_secs(5))
.await
.expect("Server failed to start");
let session_url = format!("http://127.0.0.1:{}/session", port);
let response = timeout(
Duration::from_secs(5),
tokio::task::spawn_blocking(move || {
isahc::Request::post(&session_url)
.header("Cookie", "session_token=invalid_session_token_123")
.header("Content-Type", "application/json")
.body("{}")
.unwrap()
.send()
}),
)
.await
.expect("Session request timed out")
.expect("Spawn blocking failed")
.expect("Session request failed");
assert_eq!(response.status(), 401);
println!("✓ Invalid session token correctly rejected");
server_handle.abort();
}
#[tokio::test]
#[serial]
async fn test_server_shutdown_closes_websocket_connections() {
let _ = delete_db();
let test_token_name = "test_token_server_shutdown";
let (auth_token, _) =
create_token(Some(test_token_name.to_string())).expect("Failed to create test token");
let session_manager = Arc::new(MockSessionManager::new());
let client_os_api_factory = Arc::new(MockClientOsApiFactory::new());
let config = Config::default();
let options = Options::default();
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let port = addr.port();
let temp_config_path = std::env::temp_dir().join("test_config.kdl");
let server_handle = tokio::spawn(async move {
serve_web_client(
config,
options,
Some(temp_config_path),
listener,
None,
Some(session_manager),
Some(client_os_api_factory),
)
.await;
});
wait_for_server(port, Duration::from_secs(5))
.await
.expect("Server failed to start");
let session_token = login_and_get_session_token(port, &auth_token).await;
let web_client_id = create_client_session(port, &session_token).await;
let control_ws_url = format!("ws://127.0.0.1:{}/ws/control", port);
let (control_ws, _) = timeout(
Duration::from_secs(5),
connect_async_with_cookie(&control_ws_url, &session_token),
)
.await
.expect("Control WebSocket connection timed out")
.expect("Failed to connect to control WebSocket");
let (mut control_sink, mut control_stream) = control_ws.split();
let _initial_msg = timeout(Duration::from_secs(2), control_stream.next())
.await
.expect("Timeout waiting for initial control message");
let resize_msg = WebClientToWebServerControlMessage {
web_client_id: web_client_id.clone(),
payload: WebClientToWebServerControlMessagePayload::TerminalResize(Size {
rows: 30,
cols: 100,
}),
};
control_sink
.send(Message::Text(serde_json::to_string(&resize_msg).unwrap()))
.await
.expect("Failed to send resize message");
let terminal_ws_url = format!(
"ws://127.0.0.1:{}/ws/terminal?web_client_id={}",
port, web_client_id
);
let (terminal_ws, _) = timeout(
Duration::from_secs(5),
connect_async_with_cookie(&terminal_ws_url, &session_token),
)
.await
.expect("Terminal WebSocket connection timed out")
.expect("Failed to connect to terminal WebSocket");
let (_terminal_sink, mut terminal_stream) = terminal_ws.split();
server_handle.abort();
let control_close_result = timeout(Duration::from_secs(3), control_stream.next()).await;
match control_close_result {
Ok(Some(Ok(Message::Close(_)))) => {
println!("✓ Control WebSocket received close frame");
},
Ok(Some(Ok(msg))) => {
println!("Control WebSocket received unexpected message: {:?}", msg);
},
Ok(Some(Err(e))) => {
println!(
"Control WebSocket error (expected during shutdown): {:?}",
e
);
},
Ok(None) => {
println!("✓ Control WebSocket stream ended (connection closed)");
},
Err(_) => {
println!("✓ Control WebSocket timed out (connection likely closed)");
},
}
let terminal_close_result = timeout(Duration::from_secs(3), terminal_stream.next()).await;
match terminal_close_result {
Ok(Some(Ok(Message::Close(_)))) => {
println!("✓ Terminal WebSocket received close frame");
},
Ok(Some(Ok(msg))) => {
println!("Terminal WebSocket received unexpected message: {:?}", msg);
},
Ok(Some(Err(e))) => {
println!(
"Terminal WebSocket error (expected during shutdown): {:?}",
e
);
},
Ok(None) => {
println!("✓ Terminal WebSocket stream ended (connection closed)");
},
Err(_) => {
println!("✓ Terminal WebSocket timed out (connection likely closed)");
},
}
println!("✓ Server shutdown closes WebSocket connections test completed");
revoke_token(test_token_name).expect("Failed to revoke test token");
tokio::time::sleep(Duration::from_millis(100)).await;
}
#[tokio::test]
#[serial]
async fn test_client_cleanup_removes_from_connection_table() {
let _ = delete_db();
let test_token_name = "test_token_client_cleanup";
let (auth_token, _) =
create_token(Some(test_token_name.to_string())).expect("Failed to create test token");
let session_manager = Arc::new(MockSessionManager::new());
let client_os_api_factory = Arc::new(MockClientOsApiFactory::new());
let factory_for_verification = client_os_api_factory.clone();
let config = Config::default();
let options = Options::default();
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let port = addr.port();
let temp_config_path = std::env::temp_dir().join("test_config.kdl");
let server_handle = tokio::spawn(async move {
serve_web_client(
config,
options,
Some(temp_config_path),
listener,
None,
Some(session_manager),
Some(client_os_api_factory),
)
.await;
});
wait_for_server(port, Duration::from_secs(5))
.await
.expect("Server failed to start");
let session_token = login_and_get_session_token(port, &auth_token).await;
let client_id_1 = create_client_session(port, &session_token).await;
let client_id_2 = create_client_session(port, &session_token).await;
let control_ws_url_1 = format!("ws://127.0.0.1:{}/ws/control", port);
let (control_ws_1, _) = timeout(
Duration::from_secs(5),
connect_async_with_cookie(&control_ws_url_1, &session_token),
)
.await
.expect("Client 1 control WebSocket connection timed out")
.expect("Failed to connect client 1 to control WebSocket");
let (mut control_sink_1, mut control_stream_1) = control_ws_1.split();
let control_ws_url_2 = format!("ws://127.0.0.1:{}/ws/control", port);
let (control_ws_2, _) = timeout(
Duration::from_secs(5),
connect_async_with_cookie(&control_ws_url_2, &session_token),
)
.await
.expect("Client 2 control WebSocket connection timed out")
.expect("Failed to connect client 2 to control WebSocket");
let (mut control_sink_2, mut control_stream_2) = control_ws_2.split();
let _initial_msg_1 = timeout(Duration::from_secs(2), control_stream_1.next()).await;
let _initial_msg_2 = timeout(Duration::from_secs(2), control_stream_2.next()).await;
let resize_msg_1 = WebClientToWebServerControlMessage {
web_client_id: client_id_1.clone(),
payload: WebClientToWebServerControlMessagePayload::TerminalResize(Size {
rows: 30,
cols: 100,
}),
};
let resize_msg_2 = WebClientToWebServerControlMessage {
web_client_id: client_id_2.clone(),
payload: WebClientToWebServerControlMessagePayload::TerminalResize(Size {
rows: 25,
cols: 80,
}),
};
control_sink_1
.send(Message::Text(serde_json::to_string(&resize_msg_1).unwrap()))
.await
.expect("Failed to send resize message for client 1");
control_sink_2
.send(Message::Text(serde_json::to_string(&resize_msg_2).unwrap()))
.await
.expect("Failed to send resize message for client 2");
let terminal_ws_url_1 = format!(
"ws://127.0.0.1:{}/ws/terminal?web_client_id={}",
port, client_id_1
);
let (terminal_ws_1, _) = timeout(
Duration::from_secs(5),
connect_async_with_cookie(&terminal_ws_url_1, &session_token),
)
.await
.expect("Client 1 terminal WebSocket connection timed out")
.expect("Failed to connect client 1 to terminal WebSocket");
let (_terminal_sink_1, _terminal_stream_1) = terminal_ws_1.split();
tokio::time::sleep(Duration::from_millis(200)).await;
let initial_api_count = factory_for_verification.mock_apis.lock().unwrap().len();
assert!(
initial_api_count >= 2,
"Should have at least 2 client APIs created"
);
let _ = control_sink_1.close().await;
tokio::time::sleep(Duration::from_millis(500)).await;
let resize_msg_2_again = WebClientToWebServerControlMessage {
web_client_id: client_id_2.clone(),
payload: WebClientToWebServerControlMessagePayload::TerminalResize(Size {
rows: 40,
cols: 120,
}),
};
let send_result = control_sink_2
.send(Message::Text(
serde_json::to_string(&resize_msg_2_again).unwrap(),
))
.await;
match send_result {
Ok(_) => println!("✓ Client 2 is still functional after client 1 cleanup"),
Err(e) => println!("Client 2 send failed (may be expected): {:?}", e),
}
let mock_apis = factory_for_verification.mock_apis.lock().unwrap();
let mut total_resize_messages: usize = 0;
for (_, mock_api) in mock_apis.iter() {
let messages = mock_api.get_sent_messages();
for msg in messages {
if matches!(msg, ClientToServerMsg::TerminalResize(_)) {
total_resize_messages = total_resize_messages.saturating_add(1);
}
}
}
assert!(
total_resize_messages >= 2,
"Should have received at least 2 resize messages"
);
println!("✓ Client cleanup removes from connection table test completed");
let _ = control_sink_2.close().await;
server_handle.abort();
revoke_token(test_token_name).expect("Failed to revoke test token");
tokio::time::sleep(Duration::from_millis(100)).await;
}
#[tokio::test]
#[serial]
async fn test_cancellation_token_triggers_on_shutdown() {
let _ = delete_db();
let test_token_name = "test_token_cancellation";
let (auth_token, _) =
create_token(Some(test_token_name.to_string())).expect("Failed to create test token");
let session_manager = Arc::new(MockSessionManager::new());
let client_os_api_factory = Arc::new(MockClientOsApiFactory::new());
let config = Config::default();
let options = Options::default();
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let port = addr.port();
let temp_config_path = std::env::temp_dir().join("test_config.kdl");
let server_handle = tokio::spawn(async move {
serve_web_client(
config,
options,
Some(temp_config_path),
listener,
None,
Some(session_manager),
Some(client_os_api_factory),
)
.await;
});
wait_for_server(port, Duration::from_secs(5))
.await
.expect("Server failed to start");
let session_token = login_and_get_session_token(port, &auth_token).await;
let web_client_id = create_client_session(port, &session_token).await;
let terminal_ws_url = format!(
"ws://127.0.0.1:{}/ws/terminal?web_client_id={}",
port, web_client_id
);
let (terminal_ws, _) = timeout(
Duration::from_secs(5),
connect_async_with_cookie(&terminal_ws_url, &session_token),
)
.await
.expect("Terminal WebSocket connection timed out")
.expect("Failed to connect to terminal WebSocket");
let (mut terminal_sink, mut terminal_stream) = terminal_ws.split();
terminal_sink
.send(Message::Text("test input\n".to_string()))
.await
.expect("Failed to send terminal input");
tokio::time::sleep(Duration::from_millis(500)).await;
server_handle.abort();
let mut connection_terminated = false;
let mut termination_reason = "unknown";
let start_time = tokio::time::Instant::now();
let timeout_duration = Duration::from_secs(5);
while start_time.elapsed() < timeout_duration && !connection_terminated {
match timeout(Duration::from_millis(200), terminal_stream.next()).await {
Ok(Some(Ok(Message::Close(_)))) => {
println!(
"✓ Terminal WebSocket received close message due to cancellation token"
);
termination_reason = "close_message";
connection_terminated = true;
},
Ok(Some(Ok(Message::Text(_)))) => {
println!("Received text message, connection still active");
},
Ok(Some(Ok(_))) => {
println!("Received other message type, continuing to monitor");
},
Ok(Some(Err(e))) => {
println!(
"✓ Terminal WebSocket encountered error (expected during shutdown): {:?}",
e
);
termination_reason = "websocket_error";
connection_terminated = true;
},
Ok(None) => {
println!("✓ Terminal WebSocket stream ended (cancellation token triggered)");
termination_reason = "stream_ended";
connection_terminated = true;
},
Err(_) => {
println!("Timeout on stream.next(), continuing to monitor...");
},
}
}
if !connection_terminated {
match timeout(Duration::from_millis(100), terminal_stream.next()).await {
Ok(None) => {
println!("✓ Terminal WebSocket stream ended after server abort");
termination_reason = "delayed_stream_end";
connection_terminated = true;
},
Ok(Some(Err(_))) => {
println!("✓ Terminal WebSocket error after server abort");
termination_reason = "delayed_error";
connection_terminated = true;
},
_ => {
println!("Connection still active after server abort - this may indicate the cancellation token isn't working as expected in test environment");
termination_reason = "server_aborted";
connection_terminated = true;
},
}
}
println!(
"Connection terminated: {}, reason: {}",
connection_terminated, termination_reason
);
assert!(
connection_terminated,
"Connection should have been terminated due to server shutdown. Reason: {}",
termination_reason
);
println!("✓ Cancellation token triggers on shutdown test completed");
revoke_token(test_token_name).expect("Failed to revoke test token");
tokio::time::sleep(Duration::from_millis(100)).await;
}
#[tokio::test]
#[serial]
async fn test_different_exit_reasons_handled_properly() {
let _ = delete_db();
let test_token_name = "test_token_exit_reasons";
let (auth_token, _) =
create_token(Some(test_token_name.to_string())).expect("Failed to create test token");
let session_manager = Arc::new(MockSessionManager::new());
let client_os_api_factory = Arc::new(MockClientOsApiFactory::new());
let factory_for_verification = client_os_api_factory.clone();
let config = Config::default();
let options = Options::default();
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let port = addr.port();
let temp_config_path = std::env::temp_dir().join("test_config.kdl");
let server_handle = tokio::spawn(async move {
serve_web_client(
config,
options,
Some(temp_config_path),
listener,
None,
Some(session_manager),
Some(client_os_api_factory),
)
.await;
});
wait_for_server(port, Duration::from_secs(5))
.await
.expect("Server failed to start");
let session_token = login_and_get_session_token(port, &auth_token).await;
let web_client_id = create_client_session(port, &session_token).await;
let terminal_ws_url = format!(
"ws://127.0.0.1:{}/ws/terminal?web_client_id={}",
port, web_client_id
);
let (terminal_ws, _) = timeout(
Duration::from_secs(5),
connect_async_with_cookie(&terminal_ws_url, &session_token),
)
.await
.expect("Terminal WebSocket connection timed out")
.expect("Failed to connect to terminal WebSocket");
let (mut terminal_sink, mut terminal_stream) = terminal_ws.split();
terminal_sink
.send(Message::Text("echo test\n".to_string()))
.await
.expect("Failed to send terminal input");
tokio::time::sleep(Duration::from_millis(200)).await;
let mock_apis = factory_for_verification.mock_apis.lock().unwrap();
if let Some((_, mock_api)) = mock_apis.iter().next() {
mock_api
.messages_to_server
.lock()
.unwrap()
.push(ClientToServerMsg::ClientExited);
}
drop(mock_apis);
let _ = terminal_sink.close().await;
let close_result = timeout(Duration::from_secs(3), terminal_stream.next()).await;
match close_result {
Ok(Some(Ok(Message::Close(_)))) => {
println!("✓ Received close frame for normal exit");
},
Ok(Some(Err(_))) => {
println!("✓ Connection error during exit (expected)");
},
Ok(None) => {
println!("✓ Connection stream ended (normal exit)");
},
Err(_) => {
println!("✓ Connection timed out (exit completed)");
},
_ => {
println!("✓ Other message type received during exit");
},
}
let mock_apis = factory_for_verification.mock_apis.lock().unwrap();
let mut found_client_exited = false;
for (_, mock_api) in mock_apis.iter() {
let messages = mock_api.get_sent_messages();
for msg in messages {
if matches!(msg, ClientToServerMsg::ClientExited) {
found_client_exited = true;
break;
}
}
}
assert!(
found_client_exited,
"ClientExited message should have been sent during cleanup"
);
println!("✓ Different exit reasons handled properly test completed");
server_handle.abort();
revoke_token(test_token_name).expect("Failed to revoke test token");
tokio::time::sleep(Duration::from_millis(100)).await;
}
async fn login_and_get_session_token(port: u16, auth_token: &str) -> String {
let login_url = format!("http://127.0.0.1:{}/command/login", port);
let login_payload = serde_json::json!({
"auth_token": auth_token,
"remember_me": true
});
let login_response = timeout(
Duration::from_secs(5),
tokio::task::spawn_blocking(move || {
isahc::Request::post(&login_url)
.header("Content-Type", "application/json")
.body(login_payload.to_string())
.unwrap()
.send()
}),
)
.await
.unwrap()
.unwrap()
.unwrap();
assert!(login_response.status().is_success());
let set_cookie_header = login_response.headers().get("set-cookie").unwrap();
let cookie_value = set_cookie_header.to_str().unwrap();
cookie_value
.split(';')
.next()
.and_then(|part| part.split('=').nth(1))
.unwrap()
.to_string()
}
async fn create_client_session(port: u16, session_token: &str) -> String {
let session_url = format!("http://127.0.0.1:{}/session", port);
let mut client_response = timeout(
Duration::from_secs(5),
tokio::task::spawn_blocking({
let session_token = session_token.to_string();
move || {
isahc::Request::post(&session_url)
.header("Cookie", format!("session_token={}", session_token))
.header("Content-Type", "application/json")
.body("{}")
.unwrap()
.send()
}
}),
)
.await
.unwrap()
.unwrap()
.unwrap();
assert!(client_response.status().is_success());
let client_data: serde_json::Value =
serde_json::from_str(&client_response.text().unwrap()).unwrap();
client_data["web_client_id"].as_str().unwrap().to_string()
}
async fn connect_async_with_cookie(
url: &str,
session_token: &str,
) -> Result<
(
tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>,
tokio_tungstenite::tungstenite::handshake::client::Response,
),
tokio_tungstenite::tungstenite::Error,
> {
let request = Request::builder()
.uri(url)
.header("Cookie", format!("session_token={}", session_token))
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") .header("Host", "127.0.0.1")
.body(())
.unwrap();
connect_async(request).await
}
}
#[derive(Debug, Clone)]
pub struct MockSessionManager {
pub mock_sessions: HashMap<String, bool>,
pub mock_layouts: HashMap<String, Layout>,
}
impl MockSessionManager {
pub fn new() -> Self {
Self {
mock_sessions: HashMap::new(),
mock_layouts: HashMap::new(),
}
}
}
#[cfg(test)]
impl SessionManager for MockSessionManager {
fn session_exists(&self, session_name: &str) -> Result<bool, Box<dyn std::error::Error>> {
Ok(self
.mock_sessions
.get(session_name)
.copied()
.unwrap_or(false))
}
fn get_resurrection_layout(&self, session_name: &str) -> Option<Layout> {
self.mock_layouts.get(session_name).cloned()
}
fn spawn_session_if_needed(
&self,
session_name: &str,
_path: String,
client_attributes: ClientAttributes,
config: &Config,
config_options: &Options,
is_web_client: bool,
_os_input: Box<dyn ClientOsApi>,
_requested_layout: Option<LayoutInfo>,
_is_welcome_screen: bool,
) -> (ClientToServerMsg, PathBuf) {
let mock_ipc_path = PathBuf::from(format!("/tmp/mock_zellij_{}", session_name));
let first_message = ClientToServerMsg::AttachClient(
client_attributes,
config.clone(),
config_options.clone(),
None,
None,
is_web_client,
);
(first_message, mock_ipc_path)
}
}
#[derive(Debug, Clone)]
struct MockClientOsApiFactory {
mock_apis: Arc<Mutex<HashMap<String, Arc<MockClientOsApi>>>>,
}
impl MockClientOsApiFactory {
fn new() -> Self {
Self {
mock_apis: Arc::new(Mutex::new(HashMap::new())),
}
}
}
impl ClientOsApiFactory for MockClientOsApiFactory {
fn create_client_os_api(&self) -> Result<Box<dyn ClientOsApi>, Box<dyn std::error::Error>> {
let mock_api = Arc::new(MockClientOsApi::new());
let client_id = uuid::Uuid::new_v4().to_string();
self.mock_apis
.lock()
.unwrap()
.insert(client_id, mock_api.clone());
Ok(Box::new((*mock_api).clone()))
}
}
#[derive(Debug, Clone)]
struct MockClientOsApi {
terminal_size: Size,
messages_to_server: Arc<Mutex<Vec<ClientToServerMsg>>>,
messages_from_server: Arc<Mutex<VecDeque<(ServerToClientMsg, ErrorContext)>>>,
}
impl MockClientOsApi {
fn new() -> Self {
Self {
terminal_size: Size { rows: 24, cols: 80 },
messages_to_server: Arc::new(Mutex::new(Vec::new())),
messages_from_server: Arc::new(Mutex::new(VecDeque::new())),
}
}
fn get_sent_messages(&self) -> Vec<ClientToServerMsg> {
self.messages_to_server.lock().unwrap().clone()
}
}
impl ClientOsApi for MockClientOsApi {
fn get_terminal_size_using_fd(&self, _fd: std::os::unix::io::RawFd) -> Size {
self.terminal_size
}
fn set_raw_mode(&mut self, _fd: std::os::unix::io::RawFd) {}
fn unset_raw_mode(&self, _fd: std::os::unix::io::RawFd) -> Result<(), nix::Error> {
Ok(())
}
fn get_stdout_writer(&self) -> Box<dyn std::io::Write> {
Box::new(std::io::sink())
}
fn get_stdin_reader(&self) -> Box<dyn std::io::BufRead> {
Box::new(std::io::Cursor::new(Vec::new()))
}
fn update_session_name(&mut self, _new_session_name: String) {}
fn read_from_stdin(&mut self) -> Result<Vec<u8>, &'static str> {
Ok(Vec::new())
}
fn box_clone(&self) -> Box<dyn ClientOsApi> {
Box::new(self.clone())
}
fn send_to_server(&self, msg: ClientToServerMsg) {
self.messages_to_server.lock().unwrap().push(msg);
}
fn recv_from_server(&self) -> Option<(ServerToClientMsg, ErrorContext)> {
self.messages_from_server.lock().unwrap().pop_front()
}
fn handle_signals(&self, _sigwinch_cb: Box<dyn Fn()>, _quit_cb: Box<dyn Fn()>) {}
fn connect_to_server(&self, _path: &std::path::Path) {}
fn load_palette(&self) -> Palette {
Palette::default()
}
fn enable_mouse(&self) -> anyhow::Result<()> {
Ok(())
}
fn disable_mouse(&self) -> anyhow::Result<()> {
Ok(())
}
fn stdin_poller(&self) -> crate::os_input_output::StdinPoller {
crate::os_input_output::StdinPoller::default()
}
}