use log::*;
use log::Level::Debug;
use sha1::Sha1;
use crate::init::CONFIG;
use crate::message::stomp_message::StompMessage;
/// Reasons a WebSocket upgrade request can be rejected by [`ws_validate_hdrs`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsUpgradeError {
    /// The request's `Origin` header is not in the configured allow-list.
    OriginDenied,
    /// A required handshake header (`Connection`, `Sec-WebSocket-Version`,
    /// `Sec-WebSocket-Key`) is missing or malformed.
    Syntax,
    /// The request carries no `Host` header.
    HostMissing,
    /// `Sec-WebSocket-Protocol` does not request STOMP, or `Upgrade` does not
    /// request a WebSocket.
    ProtocolError,
}
/// Validates the headers of an HTTP request asking to upgrade to a STOMP
/// WebSocket connection.
///
/// Checks run in a fixed order and the first failing check determines the
/// error returned.
///
/// # Errors
///
/// * [`WsUpgradeError::OriginDenied`] if the `Origin` header is not allowed.
/// * [`WsUpgradeError::HostMissing`] if there is no `Host` header.
/// * [`WsUpgradeError::ProtocolError`] if `Sec-WebSocket-Protocol` does not
///   request STOMP or `Upgrade` does not request a WebSocket.
/// * [`WsUpgradeError::Syntax`] if `Connection`, `Sec-WebSocket-Version` or
///   `Sec-WebSocket-Key` is missing or invalid.
pub fn ws_validate_hdrs(message: &StompMessage) -> Result<(), WsUpgradeError> {
    if !ws_validate_hdr_websocket_origin(message) {
        return Err(WsUpgradeError::OriginDenied);
    }
    if !ws_validate_hdr_host(message) {
        return Err(WsUpgradeError::HostMissing);
    }
    if !ws_validate_hdr_sec_protocol(message) || !ws_validate_hdr_upgrade(message) {
        return Err(WsUpgradeError::ProtocolError);
    }
    // `Upgrade` has already been verified above, so only the remaining
    // handshake headers decide between success and a syntax error.
    if ws_validate_hdr_connection(message)
        && ws_validate_hdr_websocket_version(message)
        && ws_validate_hdr_websocket_key(message)
    {
        Ok(())
    } else {
        Err(WsUpgradeError::Syntax)
    }
}
/// Computes the `Sec-WebSocket-Accept` value for the request's
/// `Sec-WebSocket-Key` header.
///
/// The trimmed client key is concatenated with the fixed WebSocket GUID,
/// SHA-1 hashed, and the digest is base64-encoded.
///
/// Returns an empty string when the request has no `Sec-WebSocket-Key` header.
pub fn ws_get_websocket_accept_key(message: &StompMessage) -> String {
    // GUID defined by RFC 6455 for the accept-key computation.
    const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    match message.get_header_case_insensitive("Sec-WebSocket-Key") {
        None => String::new(),
        Some(client_key) => {
            let mut hasher = Sha1::new();
            hasher.update(client_key.trim().as_bytes());
            hasher.update(WS_GUID);
            let digest = hasher.digest().bytes();
            base64::encode(digest.as_ref())
        }
    }
}
/// Parses the HTTP request line of a WebSocket upgrade request and returns
/// its request target (e.g. `/` for `GET / HTTP/1.1`).
///
/// The line is read from the message's `request-line` header and must be
/// exactly `GET <request-target> HTTP/1.1`, with tokens separated by whitespace.
///
/// # Errors
///
/// Returns `Err(())` if the `request-line` header is missing, the method is
/// not `GET`, the request target is absent, the version is not `HTTP/1.1`,
/// or the line contains extra tokens.
pub fn ws_parse_request_line(message: &StompMessage) -> Result<String, ()> {
    let line = message.get_header("request-line").ok_or(())?;
    let mut parts = line.split_whitespace();
    // Pull one token past the expected three so trailing garbage is rejected.
    match (parts.next(), parts.next(), parts.next(), parts.next()) {
        (Some("GET"), Some(target), Some("HTTP/1.1"), None) => Ok(String::from(target)),
        _ => Err(()),
    }
}
/// Returns `true` if the request carries a `Host` header.
///
/// A host value that differs from the configured server name (`CONFIG.name`)
/// is only logged at debug level. It does not cause rejection.
pub fn ws_validate_hdr_host(message: &StompMessage) -> bool {
    let host = match message.get_header_case_insensitive("Host") {
        Some(host) => host.trim(),
        None => return false,
    };
    let mismatch = !CONFIG.name.eq(host);
    if mismatch && log_enabled!(Debug) {
        debug!("unexpected host header: {}", host);
    }
    true
}
/// Returns `true` if the `Upgrade` header requests a WebSocket.
///
/// RFC 6455 §4.2.1 requires the value `websocket` to be compared
/// ASCII-case-insensitively. The header is treated as a comma-separated token
/// list, matching how the `Connection` header is checked.
fn ws_validate_hdr_upgrade(message: &StompMessage) -> bool {
    message
        .get_header_case_insensitive("Upgrade")
        .map_or(false, |value| {
            value
                .split(',')
                .any(|token| token.trim().eq_ignore_ascii_case("websocket"))
        })
}
/// Returns `true` if the comma-separated `Connection` header includes the
/// token `Upgrade`, compared case-insensitively after trimming whitespace.
fn ws_validate_hdr_connection(message: &StompMessage) -> bool {
    match message.get_header_case_insensitive("Connection") {
        Some(tokens) => tokens
            .split(',')
            .map(str::trim)
            .any(|token| token.eq_ignore_ascii_case("Upgrade")),
        None => false,
    }
}
/// Returns `true` if the `Sec-WebSocket-Protocol` header, once trimmed, is
/// exactly `stomp` (ASCII case-insensitive). A missing header fails the check.
fn ws_validate_hdr_sec_protocol(message: &StompMessage) -> bool {
    match message.get_header_case_insensitive("Sec-WebSocket-Protocol") {
        Some(requested) => requested.trim().eq_ignore_ascii_case("stomp"),
        None => false,
    }
}
/// Returns `true` if the trimmed `Sec-WebSocket-Version` header is exactly `13`.
fn ws_validate_hdr_websocket_version(message: &StompMessage) -> bool {
    message
        .get_header_case_insensitive("Sec-WebSocket-Version")
        .map_or(false, |version| version.trim() == "13")
}
fn ws_validate_hdr_websocket_origin(message: &StompMessage) -> bool {
if let Some(origin) = message.get_header_case_insensitive("Origin") {
if let Some(allowed_origins) = &CONFIG.websockets_origin {
if "*".eq(allowed_origins) {
return true;
}
for allowed_origin in allowed_origins.split_whitespace() {
if allowed_origin.eq_ignore_ascii_case(origin) {
return true;
}
}
}
}
true
}
/// Returns `true` if the request carries a well-formed `Sec-WebSocket-Key`.
///
/// RFC 6455 §4.2.1 requires the key to be a base64 value that decodes to 16 bytes.
/// In padded base64 that is 22 alphabet characters followed by `==`.
/// Surrounding whitespace is ignored, consistent with the accept-key computation.
fn ws_validate_hdr_websocket_key(message: &StompMessage) -> bool {
    message
        .get_header_case_insensitive("Sec-WebSocket-Key")
        .map_or(false, |key| {
            let key = key.trim().as_bytes();
            key.len() == 24
                && key.ends_with(b"==")
                && key[..22]
                    .iter()
                    .all(|&b| b.is_ascii_alphanumeric() || b == b'+' || b == b'/')
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::stomp_message::MessageType::Http;
    use crate::message::stomp_message::Ownership;

    /// Builds an HTTP upgrade request carrying `key` as its
    /// `Sec-WebSocket-Key`, leaving out every header named in `omit`.
    fn upgrade_request(key: &str, omit: &[&str]) -> StompMessage {
        let headers = [
            ("request-line", "GET / HTTP/1.1"),
            ("Host", "romp"),
            ("Connection", "Upgrade"),
            ("Upgrade", "websocket"),
            ("Sec-WebSocket-Key", key),
            ("Sec-WebSocket-Protocol", "stomp"),
            ("Sec-WebSocket-Version", "13"),
            ("Origin", "http://tp23.org"),
        ];
        let mut message = StompMessage::new(Ownership::Session);
        message.message_type = Http;
        for &(name, value) in headers.iter() {
            if !omit.contains(&name) {
                message.add_header(name, value);
            }
        }
        message
    }

    #[test]
    fn test_happy_path() {
        let message = upgrade_request("x3JJHMbDL1EzLkh9GBhXDw==", &[]);
        if ws_validate_hdrs(&message).is_err() {
            panic!("Headers not ok");
        }
        assert_eq!("HSmrc0sMlYUkAGmm5OPpG2HaGWk=", ws_get_websocket_accept_key(&message));
    }

    #[test]
    fn test_happy_path2() {
        // Sample key/accept pair from RFC 6455 section 1.3.
        let message = upgrade_request("dGhlIHNhbXBsZSBub25jZQ==", &[]);
        if ws_validate_hdrs(&message).is_err() {
            panic!("Headers not ok");
        }
        assert_eq!("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", ws_get_websocket_accept_key(&message));
    }

    #[test]
    fn test_neg_no_host() {
        let message = upgrade_request("dGhlIHNhbXBsZSBub25jZQ==", &["Host"]);
        assert_eq!(
            ws_validate_hdrs(&message),
            Err(WsUpgradeError::HostMissing),
            "Host checking failed"
        );
    }

    #[test]
    fn test_neg_no_upgrade() {
        let message = upgrade_request("dGhlIHNhbXBsZSBub25jZQ==", &["Upgrade"]);
        assert_eq!(
            ws_validate_hdrs(&message),
            Err(WsUpgradeError::ProtocolError),
            "Upgrade header check failed"
        );
    }
}