use crate::headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade};
use crate::{Request, Response, Result, SilentError, StatusCode, header};
/// Validates a WebSocket opening handshake and builds the upgrade response.
///
/// The request must carry an `Upgrade` header whose value is `websocket`
/// (compared ASCII case-insensitively) and a `Sec-WebSocket-Key` header that
/// can be read as a [`SecWebsocketKey`]. On success the returned response has
/// status `101 Switching Protocols` with the `Connection`, `Upgrade` and
/// `Sec-WebSocket-Accept` headers set; the accept value is derived from the
/// request key.
///
/// # Errors
///
/// Returns [`SilentError::BusinessError`] with [`StatusCode::BAD_REQUEST`] when:
/// - the `Upgrade` header is absent,
/// - the `Upgrade` header is not readable as a string or is not `websocket`,
/// - the `Sec-WebSocket-Key` header cannot be read from the request headers.
pub fn websocket_handler(req: &Request) -> Result<Response> {
    // Every rejection is a 400 that differs only in its message.
    let bad_request = |msg: &str| SilentError::BusinessError {
        code: StatusCode::BAD_REQUEST,
        msg: msg.to_string(),
    };
    let req_headers = req.headers();

    // A single lookup both detects a missing header and yields its value.
    let upgrade = req_headers
        .get(header::UPGRADE)
        .ok_or_else(|| bad_request("bad request: not upgrade"))?;

    // ASCII case-insensitive comparison without allocating a lowercased copy.
    // A value that cannot be viewed as a string is rejected the same way as a
    // value naming another protocol.
    if !matches!(upgrade.to_str(), Ok(v) if v.eq_ignore_ascii_case("websocket")) {
        return Err(bad_request("bad request: not websocket"));
    }

    let sec_ws_key = req_headers
        .typed_get::<SecWebsocketKey>()
        .ok_or_else(|| {
            bad_request("bad request: sec_websocket_key is not exist in request headers")
        })?;

    // Only build the response once the request is known to be acceptable.
    let mut res = Response::empty();
    res.set_status(StatusCode::SWITCHING_PROTOCOLS);
    res.headers.typed_insert(Connection::upgrade());
    res.headers.typed_insert(Upgrade::websocket());
    res.headers
        .typed_insert(SecWebsocketAccept::from(sec_ws_key));
    Ok(res)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::header::HeaderValue;

    /// Sample nonce from RFC 6455 section 1.3.
    const SAMPLE_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";
    /// `Sec-WebSocket-Accept` value RFC 6455 section 1.3 gives for `SAMPLE_KEY`.
    const SAMPLE_ACCEPT: &str = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";

    /// Builds an empty request carrying exactly the given headers.
    fn request_with(headers: &[(&'static str, &'static str)]) -> Request {
        let mut req = Request::empty();
        for &(name, value) in headers {
            req.headers_mut()
                .insert(name, HeaderValue::from_static(value));
        }
        req
    }

    /// Builds a complete handshake request with the given `Upgrade` value and key.
    fn handshake_request(upgrade: &'static str, key: &'static str) -> Request {
        request_with(&[
            ("upgrade", upgrade),
            ("sec-websocket-key", key),
            ("connection", "Upgrade"),
        ])
    }

    /// Builds a complete handshake request using the RFC sample key.
    fn sample_handshake(upgrade: &'static str) -> Request {
        handshake_request(upgrade, SAMPLE_KEY)
    }

    /// Reads the `Sec-WebSocket-Accept` response header as a string.
    fn accept_value(res: &Response) -> &str {
        res.headers()
            .get("sec-websocket-accept")
            .expect("sec-websocket-accept header should be present")
            .to_str()
            .expect("sec-websocket-accept header should be a valid string")
    }

    #[test]
    fn test_websocket_handler_valid_request() {
        let req = sample_handshake("websocket");
        let res = websocket_handler(&req);
        assert!(res.is_ok());
        let res = res.unwrap();
        assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS);
    }

    #[test]
    fn test_websocket_handler_missing_upgrade_header() {
        let req = Request::empty();
        let res = websocket_handler(&req);
        assert!(res.is_err());
        let err = res.unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_websocket_handler_invalid_upgrade_value() {
        let req = request_with(&[("upgrade", "http/2")]);
        let res = websocket_handler(&req);
        assert!(res.is_err());
        let err = res.unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_websocket_handler_upgrade_case_insensitive() {
        let req = sample_handshake("WebSocket");
        let res = websocket_handler(&req);
        assert!(res.is_ok());
        let res = res.unwrap();
        assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS);
    }

    #[test]
    fn test_websocket_handler_missing_sec_websocket_key() {
        let req = request_with(&[("upgrade", "websocket")]);
        let res = websocket_handler(&req);
        assert!(res.is_err());
        let err = res.unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_websocket_handler_response_headers() {
        let req = sample_handshake("websocket");
        let res = websocket_handler(&req).unwrap();
        assert!(
            res.headers()
                .get("connection")
                .map(|v| v.to_str().unwrap().to_lowercase().contains("upgrade"))
                .unwrap_or(false)
        );
        assert!(
            res.headers()
                .get("upgrade")
                .map(|v| v.to_str().unwrap().eq_ignore_ascii_case("websocket"))
                .unwrap_or(false)
        );
        assert!(res.headers().get("sec-websocket-accept").is_some());
    }

    #[test]
    fn test_websocket_handler_sec_websocket_accept() {
        let req = sample_handshake("websocket");
        let res = websocket_handler(&req).unwrap();
        // Pin the exact value from the RFC example rather than mere non-emptiness.
        assert_eq!(accept_value(&res), SAMPLE_ACCEPT);
    }

    #[test]
    fn test_websocket_handler_response_status() {
        let req = sample_handshake("websocket");
        let res = websocket_handler(&req).unwrap();
        assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS);
    }

    #[test]
    fn test_websocket_handler_error_not_upgrade() {
        let req = Request::empty();
        let res = websocket_handler(&req);
        assert!(res.is_err());
        let err = res.unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("not upgrade"));
    }

    #[test]
    fn test_websocket_handler_error_not_websocket() {
        let req = request_with(&[("upgrade", "http/2")]);
        let res = websocket_handler(&req);
        assert!(res.is_err());
        let err = res.unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("not websocket"));
    }

    #[test]
    fn test_websocket_handler_error_missing_key() {
        let req = request_with(&[("upgrade", "websocket")]);
        let res = websocket_handler(&req);
        assert!(res.is_err());
        let err = res.unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert!(err.message().contains("sec_websocket_key"));
    }

    #[test]
    fn test_websocket_handler_with_connection_header() {
        let req = sample_handshake("websocket");
        let res = websocket_handler(&req);
        assert!(res.is_ok());
    }

    #[test]
    fn test_websocket_handler_empty_key() {
        let req = request_with(&[("upgrade", "websocket"), ("sec-websocket-key", "")]);
        // Whether an empty key is accepted depends on how the typed header is
        // decoded, so either outcome is allowed, but it must be well-formed:
        // a 101 response carrying an accept header, or a 400 error.
        match websocket_handler(&req) {
            Ok(res) => {
                assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS);
                assert!(res.headers().get("sec-websocket-accept").is_some());
            }
            Err(err) => assert_eq!(err.status(), StatusCode::BAD_REQUEST),
        }
    }

    #[test]
    fn test_websocket_handler_different_key() {
        let req = handshake_request("websocket", "YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXo=");
        let res = websocket_handler(&req);
        assert!(res.is_ok());
        let res = res.unwrap();
        let accept = accept_value(&res);
        assert!(!accept.is_empty());
        // The accept value is derived from the key, so a different key must
        // not yield the accept value of the sample key.
        assert_ne!(accept, SAMPLE_ACCEPT);
    }

    #[test]
    fn test_websocket_handler_upgrade_lowercase() {
        let req = sample_handshake("websocket");
        let res = websocket_handler(&req);
        assert!(res.is_ok());
    }

    #[test]
    fn test_websocket_handler_upgrade_uppercase() {
        let req = sample_handshake("WEBSOCKET");
        let res = websocket_handler(&req);
        assert!(res.is_ok());
    }

    #[test]
    fn test_websocket_handler_upgrade_mixed_case() {
        let req = sample_handshake("WeBsOcKeT");
        let res = websocket_handler(&req);
        assert!(res.is_ok());
    }

    #[test]
    fn test_websocket_handler_with_additional_headers() {
        let req = request_with(&[
            ("upgrade", "websocket"),
            ("sec-websocket-key", SAMPLE_KEY),
            ("connection", "Upgrade"),
            ("sec-websocket-version", "13"),
            ("sec-websocket-protocol", "chat"),
        ]);
        let res = websocket_handler(&req);
        assert!(res.is_ok());
        let res = res.unwrap();
        assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS);
    }

    #[test]
    fn test_websocket_handler_status_switching_protocols() {
        let req = sample_handshake("websocket");
        let res = websocket_handler(&req).unwrap();
        // Compare against the raw numeric code as well as the named constant.
        assert_eq!(res.status(), 101);
    }

    #[test]
    fn test_websocket_handler_bad_request_status() {
        let req = Request::empty();
        let res = websocket_handler(&req);
        assert!(res.is_err());
        let err = res.unwrap_err();
        assert_eq!(err.status(), 400);
    }
}