use bytes::Bytes;
use clasp_core::{
codec, HelloMessage, Message, SetMessage, SubscribeMessage, Value, PROTOCOL_VERSION,
};
use clasp_test_utils::TestRouter;
use clasp_transport::{
Transport, TransportEvent, TransportReceiver, TransportSender, WebSocketTransport,
};
use std::time::Duration;
use tokio::time::timeout;
async fn recv_message(
receiver: &mut impl TransportReceiver,
max_wait: Duration,
) -> Option<Message> {
let deadline = tokio::time::Instant::now() + max_wait;
loop {
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
if remaining.is_zero() {
return None;
}
match timeout(remaining, receiver.recv()).await {
Ok(Some(TransportEvent::Data(data))) => {
match codec::decode(&data) {
Ok((msg, _)) => return Some(msg),
Err(_) => continue, }
}
Ok(Some(TransportEvent::Connected)) => continue,
Ok(Some(TransportEvent::Disconnected { .. })) => return None,
Ok(Some(TransportEvent::Error(_))) => return None,
Ok(None) => return None,
Err(_) => return None, }
}
}
/// Sends a HELLO named `name` and drains the WELCOME / SNAPSHOT replies.
///
/// Stops once both have arrived, the 5 s overall deadline passes, or a single
/// 500 ms read yields nothing. Returns `true` iff a WELCOME was seen (a missing
/// SNAPSHOT does not fail the handshake); returns `false` if the HELLO could not
/// be sent.
///
/// # Panics
/// Panics if the HELLO fails to encode.
async fn complete_handshake(
    sender: &impl TransportSender,
    receiver: &mut impl TransportReceiver,
    name: &str,
) -> bool {
    let hello = Message::Hello(HelloMessage {
        version: PROTOCOL_VERSION,
        name: name.to_string(),
        features: vec![],
        capabilities: None,
        token: None,
    });
    let frame = codec::encode(&hello).unwrap();
    if sender.send(frame).await.is_err() {
        return false;
    }

    let (mut welcomed, mut snapshotted) = (false, false);
    let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
    while !(welcomed && snapshotted) && tokio::time::Instant::now() < deadline {
        match recv_message(receiver, Duration::from_millis(500)).await {
            Some(Message::Welcome(_)) => welcomed = true,
            Some(Message::Snapshot(_)) => snapshotted = true,
            Some(_) => {}
            // Quiet for 500 ms (or the link went away): stop waiting.
            None => break,
        }
    }
    welcomed
}
#[tokio::test]
async fn test_malformed_message_returns_error_400() {
let router = TestRouter::start().await;
let (sender, mut receiver) = WebSocketTransport::connect(&router.url())
.await
.expect("Failed to connect");
assert!(
complete_handshake(&sender, &mut receiver, "MalformedTest").await,
"Handshake should succeed"
);
let garbage = Bytes::from(vec![0xFF, 0xFE, 0xFD, 0xFC, 0x00, 0x01, 0x02]);
sender.send(garbage).await.expect("Failed to send");
let response = recv_message(&mut receiver, Duration::from_secs(2)).await;
match response {
Some(Message::Error(err)) => {
assert!(
err.code == 400 || err.code == 0,
"Malformed message should return error code 400, got {}",
err.code
);
}
None => {
}
Some(other) => {
assert!(
!matches!(other, Message::Ack(_)),
"Server should NOT ACK malformed messages"
);
}
}
}
/// A HELLO cut down to its first 3 bytes must not complete the handshake.
///
/// An ERROR with code 400 (or 0) passes, silence passes, and any other reply
/// passes as long as it is not a WELCOME.
#[tokio::test]
async fn test_truncated_message_returns_error_400() {
    let router = TestRouter::start().await;
    let (sender, mut receiver) = WebSocketTransport::connect(&router.url())
        .await
        .expect("Failed to connect");
    let hello = Message::Hello(HelloMessage {
        version: PROTOCOL_VERSION,
        name: "Test".to_string(),
        features: vec![],
        capabilities: None,
        token: None,
    });
    let bytes = codec::encode(&hello).expect("Failed to encode");
    // Keep at most the first 3 bytes (a shorter frame is sent whole). Copy once,
    // straight from the slice, instead of allocating two intermediate Vecs.
    let truncated = Bytes::copy_from_slice(&bytes[..bytes.len().min(3)]);
    sender.send(truncated).await.expect("Failed to send");
    let response = recv_message(&mut receiver, Duration::from_secs(2)).await;
    match response {
        Some(Message::Error(err)) => {
            assert!(
                err.code == 400 || err.code == 0,
                "Truncated message should return error code 400, got {}",
                err.code
            );
        }
        None => {
            // Silently dropping the broken frame is acceptable.
        }
        Some(other) => {
            assert!(
                !matches!(other, Message::Welcome(_)),
                "Server should NOT send WELCOME for truncated HELLO"
            );
        }
    }
}
/// HELLO carrying protocol version 99.
///
/// An ERROR with code 505 or 400 passes. A WELCOME also passes and is logged as
/// forward-compatible mode, and silence passes. Any other reply fails the test.
#[tokio::test]
async fn test_wrong_protocol_version_returns_error_505() {
    let router = TestRouter::start().await;
    let (sender, mut receiver) = WebSocketTransport::connect(&router.url())
        .await
        .expect("Failed to connect");

    // NOTE(review): assumes 99 is not a version the router supports.
    let bad_hello = Message::Hello(HelloMessage {
        version: 99,
        name: "BadVersion".to_string(),
        features: vec![],
        capabilities: None,
        token: None,
    });
    let frame = codec::encode(&bad_hello).expect("Failed to encode");
    sender.send(frame).await.expect("Failed to send");

    match recv_message(&mut receiver, Duration::from_secs(2)).await {
        None => {}
        Some(Message::Welcome(_)) => {
            eprintln!("Note: Server accepted unsupported version 99 (forward-compatible mode)");
        }
        Some(Message::Error(err)) => {
            assert!(
                err.code == 505 || err.code == 400,
                "Wrong protocol version should return error code 505 or 400, got {}",
                err.code
            );
        }
        Some(other) => {
            panic!(
                "Expected ERROR or WELCOME for version mismatch, got {:?}",
                std::mem::discriminant(&other)
            );
        }
    }
}
/// A SET sent before any HELLO must not be acknowledged.
///
/// An ERROR with code 401 or 400 passes and an ACK fails. Silence or any other
/// reply is tolerated.
#[tokio::test]
async fn test_message_before_hello_returns_error_401() {
    let router = TestRouter::start().await;
    let (sender, mut receiver) = WebSocketTransport::connect(&router.url())
        .await
        .expect("Failed to connect");

    // Deliberately skip the handshake.
    let premature_set = Message::Set(SetMessage {
        address: "/test".to_string(),
        value: Value::Int(1),
        revision: None,
        lock: false,
        unlock: false,
        ttl: None,
    });
    let frame = codec::encode(&premature_set).expect("Failed to encode");
    sender.send(frame).await.expect("Failed to send");

    match recv_message(&mut receiver, Duration::from_secs(2)).await {
        Some(Message::Ack(_)) => {
            panic!("Server MUST NOT ACK messages before HELLO handshake is complete");
        }
        Some(Message::Error(err)) => {
            assert!(
                err.code == 401 || err.code == 400,
                "Message before HELLO should return error code 401 or 400, got {}",
                err.code
            );
        }
        // Silence, or any other reply, is tolerated.
        _ => {}
    }
}
#[tokio::test]
async fn test_duplicate_hello() {
let router = TestRouter::start().await;
let (sender, mut receiver) = WebSocketTransport::connect(&router.url())
.await
.expect("Failed to connect");
let hello = Message::Hello(HelloMessage {
version: PROTOCOL_VERSION,
name: "First".to_string(),
features: vec![],
capabilities: None,
token: None,
});
sender
.send(codec::encode(&hello).expect("Failed to encode"))
.await
.expect("Failed to send");
let got_welcome = loop {
match timeout(Duration::from_secs(2), receiver.recv()).await {
Ok(Some(TransportEvent::Data(data))) => {
let (msg, _) = codec::decode(&data).expect("Failed to decode");
if matches!(msg, Message::Welcome(_)) {
break true;
}
}
Ok(Some(TransportEvent::Connected)) => continue,
_ => break false,
}
};
assert!(got_welcome, "Expected WELCOME message");
let hello2 = Message::Hello(HelloMessage {
version: PROTOCOL_VERSION,
name: "Second".to_string(),
features: vec![],
capabilities: None,
token: None,
});
sender
.send(codec::encode(&hello2).expect("Failed to encode"))
.await
.expect("Failed to send");
let _response = timeout(Duration::from_millis(500), receiver.recv()).await;
}
#[tokio::test]
async fn test_very_long_address() {
let router = TestRouter::start().await;
let (sender, mut receiver) = WebSocketTransport::connect(&router.url())
.await
.expect("Failed to connect");
let hello = Message::Hello(HelloMessage {
version: PROTOCOL_VERSION,
name: "LongAddressTest".to_string(),
features: vec![],
capabilities: None,
token: None,
});
sender
.send(codec::encode(&hello).expect("Failed to encode"))
.await
.expect("Failed to send");
loop {
match timeout(Duration::from_secs(2), receiver.recv()).await {
Ok(Some(TransportEvent::Data(data))) => {
let (msg, _) = codec::decode(&data).expect("Failed to decode");
if matches!(msg, Message::Snapshot(_)) {
break;
}
}
Ok(Some(TransportEvent::Connected)) => continue,
_ => break,
}
}
let long_addr = format!("/{}", "a".repeat(10_000));
let set = Message::Set(SetMessage {
address: long_addr,
value: Value::Int(1),
revision: None,
lock: false,
unlock: false,
ttl: None,
});
sender
.send(codec::encode(&set).expect("Failed to encode"))
.await
.expect("Failed to send");
let _response = timeout(Duration::from_secs(1), receiver.recv()).await;
}
/// A SET with an empty address should be rejected with ERROR 400.
///
/// An ACK is tolerated and logged as permissive mode. Silence and any other
/// reply are ignored.
#[tokio::test]
async fn test_empty_address_returns_error_400() {
    let router = TestRouter::start().await;
    let (sender, mut receiver) = WebSocketTransport::connect(&router.url())
        .await
        .expect("Failed to connect");
    let handshake_ok = complete_handshake(&sender, &mut receiver, "EmptyAddressTest").await;
    assert!(handshake_ok, "Handshake should succeed");

    let empty_set = Message::Set(SetMessage {
        address: String::new(),
        value: Value::Int(1),
        revision: None,
        lock: false,
        unlock: false,
        ttl: None,
    });
    let frame = codec::encode(&empty_set).expect("Failed to encode");
    sender.send(frame).await.expect("Failed to send");

    match recv_message(&mut receiver, Duration::from_secs(2)).await {
        Some(Message::Ack(_)) => {
            eprintln!("Note: Server accepted empty address (permissive mode)");
        }
        Some(Message::Error(err)) => {
            assert_eq!(
                err.code, 400,
                "Empty address should return error code 400, got {}",
                err.code
            );
        }
        _ => {}
    }
}
#[tokio::test]
async fn test_invalid_address_returns_error_400() {
let router = TestRouter::start().await;
let (sender, mut receiver) = WebSocketTransport::connect(&router.url())
.await
.expect("Failed to connect");
assert!(
complete_handshake(&sender, &mut receiver, "InvalidAddressTest").await,
"Handshake should succeed"
);
let invalid_addresses = vec![
"//double/slash", "no/leading/slash", "/unclosed/**/wildcard/**", ];
for addr in invalid_addresses {
let set = Message::Set(SetMessage {
address: addr.to_string(),
value: Value::Int(1),
revision: None,
lock: false,
unlock: false,
ttl: None,
});
sender
.send(codec::encode(&set).expect("Failed to encode"))
.await
.expect("Failed to send");
let response = recv_message(&mut receiver, Duration::from_millis(500)).await;
match response {
Some(Message::Error(err)) => {
assert!(
err.code == 400 || err.code == 0,
"Invalid address '{}' should return error code 400, got {}",
addr,
err.code
);
}
Some(Message::Ack(_)) => {
}
None => {
}
Some(_) => {
}
}
}
}
/// Opens, greets and closes five connections back to back.
///
/// Every connect, encode, send and close must succeed. Replies are not checked
/// beyond a 100 ms best-effort read.
#[tokio::test]
async fn test_rapid_disconnect_reconnect() {
    let router = TestRouter::start().await;
    for attempt in 0..5 {
        let (sender, mut receiver) = WebSocketTransport::connect(&router.url())
            .await
            .expect("Failed to connect");
        let frame = codec::encode(&Message::Hello(HelloMessage {
            version: PROTOCOL_VERSION,
            name: format!("Rapid{}", attempt),
            features: vec![],
            capabilities: None,
            token: None,
        }))
        .expect("Failed to encode");
        sender.send(frame).await.expect("Failed to send");
        // Give the router a brief chance to react; the outcome is ignored.
        let _ = timeout(Duration::from_millis(100), receiver.recv()).await;
        sender.close().await.expect("Failed to close");
    }
}
/// Connecting to `ws://127.0.0.1:1` must not succeed.
///
/// Either a connect error or hitting the 2 s timeout passes. Only a successful
/// connection fails the test. (Assumes nothing listens on port 1 on the test host.)
#[tokio::test]
async fn test_connection_to_closed_port() {
    let attempt = timeout(
        Duration::from_secs(2),
        WebSocketTransport::connect("ws://127.0.0.1:1"),
    )
    .await;
    if let Ok(Ok(_)) = attempt {
        panic!("Should not connect to closed port");
    }
}
#[tokio::test]
async fn test_special_characters_in_address() {
let router = TestRouter::start().await;
let (sender, mut receiver) = WebSocketTransport::connect(&router.url())
.await
.expect("Failed to connect");
let hello = Message::Hello(HelloMessage {
version: PROTOCOL_VERSION,
name: "SpecialChars".to_string(),
features: vec![],
capabilities: None,
token: None,
});
sender
.send(codec::encode(&hello).expect("Failed to encode"))
.await
.expect("Failed to send");
loop {
match timeout(Duration::from_secs(2), receiver.recv()).await {
Ok(Some(TransportEvent::Data(data))) => {
let (msg, _) = codec::decode(&data).expect("Failed to decode");
if matches!(msg, Message::Snapshot(_)) {
break;
}
}
Ok(Some(TransportEvent::Connected)) => continue,
_ => break,
}
}
let special_addresses = vec![
"/path/with spaces",
"/path/with\ttabs",
"/unicode/\u{65e5}\u{672c}\u{8a9e}",
"/emoji/\u{1f3b5}",
"/symbols/@#$%",
];
for addr in special_addresses {
let set = Message::Set(SetMessage {
address: addr.to_string(),
value: Value::Int(1),
revision: None,
lock: false,
unlock: false,
ttl: None,
});
sender
.send(codec::encode(&set).expect("Failed to encode"))
.await
.expect("Failed to send");
let _ = timeout(Duration::from_millis(100), receiver.recv()).await;
}
}
#[tokio::test]
async fn test_unauthorized_write_to_locked_address_returns_error_403() {
let router = TestRouter::start().await;
let (owner_sender, mut owner_receiver) = WebSocketTransport::connect(&router.url())
.await
.expect("Failed to connect owner");
assert!(
complete_handshake(&owner_sender, &mut owner_receiver, "Owner").await,
"Owner handshake should succeed"
);
let set_locked = Message::Set(SetMessage {
address: "/locked/value".to_string(),
value: Value::Int(100),
revision: None,
lock: true, unlock: false,
ttl: None,
});
owner_sender
.send(codec::encode(&set_locked).expect("Failed to encode"))
.await
.expect("Failed to send");
let owner_response = recv_message(&mut owner_receiver, Duration::from_secs(2)).await;
assert!(
matches!(owner_response, Some(Message::Ack(_))),
"Owner should receive ACK for locked set"
);
let (intruder_sender, mut intruder_receiver) = WebSocketTransport::connect(&router.url())
.await
.expect("Failed to connect intruder");
assert!(
complete_handshake(&intruder_sender, &mut intruder_receiver, "Intruder").await,
"Intruder handshake should succeed"
);
let set_intruder = Message::Set(SetMessage {
address: "/locked/value".to_string(),
value: Value::Int(999), revision: None,
lock: false,
unlock: false,
ttl: None,
});
intruder_sender
.send(codec::encode(&set_intruder).expect("Failed to encode"))
.await
.expect("Failed to send");
let intruder_response = recv_message(&mut intruder_receiver, Duration::from_secs(2)).await;
match intruder_response {
Some(Message::Error(err)) => {
assert!(
err.code >= 400 && err.code < 500,
"Write to locked address should return 4xx error, got {}",
err.code
);
}
Some(Message::Ack(_)) => {
panic!("Server MUST NOT ACK writes to locked addresses from non-owners");
}
None => {
eprintln!("Note: Server silently ignored write to locked address");
}
Some(other) => {
panic!(
"Unexpected response to locked address write: {:?}",
std::mem::discriminant(&other)
);
}
}
}
/// Probes how the router reacts to SUBSCRIBE with malformed patterns.
///
/// Observational only: ERROR and ACK replies are logged, and nothing is asserted
/// beyond the handshake succeeding and every frame encoding and sending.
#[tokio::test]
async fn test_subscribe_invalid_pattern_returns_error_400() {
    let router = TestRouter::start().await;
    let (sender, mut receiver) = WebSocketTransport::connect(&router.url())
        .await
        .expect("Failed to connect");
    assert!(
        complete_handshake(&sender, &mut receiver, "PatternTest").await,
        "Handshake should succeed"
    );
    let invalid_patterns = vec![
        "",                 // empty pattern
        "no/leading/slash", // missing leading '/'
    ];
    // Give each probe its own subscription id (1, 2, ...). Reusing a single id
    // would turn the second probe into a duplicate-id subscription whenever the
    // first pattern was accepted, mixing two behaviors. Duplicate ids are covered
    // by `test_duplicate_subscription_id`.
    for (id, pattern) in (1..).zip(invalid_patterns) {
        let subscribe = Message::Subscribe(SubscribeMessage {
            id,
            pattern: pattern.to_string(),
            types: vec![],
            options: None,
        });
        sender
            .send(codec::encode(&subscribe).expect("Failed to encode"))
            .await
            .expect("Failed to send");
        let response = recv_message(&mut receiver, Duration::from_secs(1)).await;
        match response {
            Some(Message::Error(err)) => {
                eprintln!(
                    "Note: Server returned error {} for pattern '{}'",
                    err.code, pattern
                );
            }
            Some(Message::Ack(_)) => {
                eprintln!(
                    "Note: Server accepted pattern '{}' (permissive mode)",
                    pattern
                );
            }
            None => {}
            Some(_) => {}
        }
    }
}
#[tokio::test]
async fn test_duplicate_subscription_id() {
let router = TestRouter::start().await;
let (sender, mut receiver) = WebSocketTransport::connect(&router.url())
.await
.expect("Failed to connect");
assert!(
complete_handshake(&sender, &mut receiver, "DuplicateSubTest").await,
"Handshake should succeed"
);
let subscribe1 = Message::Subscribe(SubscribeMessage {
id: 42,
pattern: "/test/a".to_string(),
types: vec![],
options: None,
});
sender
.send(codec::encode(&subscribe1).expect("Failed to encode"))
.await
.expect("Failed to send");
let _ = recv_message(&mut receiver, Duration::from_secs(1)).await;
let subscribe2 = Message::Subscribe(SubscribeMessage {
id: 42, pattern: "/test/b".to_string(),
types: vec![],
options: None,
});
sender
.send(codec::encode(&subscribe2).expect("Failed to encode"))
.await
.expect("Failed to send");
let response = recv_message(&mut receiver, Duration::from_secs(1)).await;
match response {
Some(Message::Ack(_)) => {
}
Some(Message::Error(err)) => {
assert!(
err.code == 400 || err.code == 409,
"Duplicate subscription should return 400 or 409, got {}",
err.code
);
}
None => {
}
Some(_) => {}
}
}