use std::io::{Read, Write, Result};
use std::task::Poll;
use super::detail;
use super::Endpoint;
use crate::role::ServerRole;
use crate::handshake::{HttpHeader, Request, Response};
use crate::handshake::derive_accept_key;
use crate::error::HandshakeError;
use crate::stream::Stream;
/// Blocking (synchronous) server-side handshake helpers.
///
/// These are available when `IO` implements both [`Read`] and [`Write`] and
/// `Role` implements `ServerRole`.
impl<IO: Read + Write, Role: ServerRole> Endpoint<IO, Role> {
    /// Sends an upgrade `response` through a blocking writer.
    ///
    /// The work is delegated to `detail::send_response`, which receives `buf`
    /// as scratch space and a closure that forwards each chunk to
    /// [`Write::write`] and wraps the outcome in `Poll::Ready`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `detail::send_response`.
    ///
    /// # Panics
    ///
    /// Panics if `detail::send_response` yields `Poll::Pending`. The write
    /// closure itself always produces a ready value, so this is treated as
    /// unreachable.
    pub fn send_response<const N: usize>(
        io: &mut IO,
        buf: &mut [u8],
        response: &Response<'_, '_, N>,
    ) -> Result<usize> {
        let outcome =
            detail::send_response(io, buf, response, |writer, bytes| writer.write(bytes).into());
        if let Poll::Ready(sent) = outcome {
            sent
        } else {
            unreachable!()
        }
    }

    /// Receives an upgrade request from a blocking reader into `request`.
    ///
    /// The work is delegated to `detail::recv_request`, which receives `buf`
    /// as storage and a closure that forwards each fill to [`Read::read`] and
    /// wraps the outcome in `Poll::Ready`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract lives in `detail::recv_request` and is
    /// not visible from this function. Presumably `request` ends up referring
    /// to bytes stored in `buf` without the borrow checker tracking it, so the
    /// caller must not modify `buf` while `request` is still being read —
    /// confirm against `detail::recv_request`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `detail::recv_request`.
    ///
    /// # Panics
    ///
    /// Panics if `detail::recv_request` yields `Poll::Pending`. The read
    /// closure itself always produces a ready value, so this is treated as
    /// unreachable.
    pub unsafe fn recv_request<'h, 'b: 'h, const N: usize>(
        io: &mut IO,
        buf: &mut [u8],
        request: &mut Request<'h, 'b, N>,
    ) -> Result<usize> {
        let outcome =
            detail::recv_request(io, buf, request, |reader, bytes| reader.read(bytes).into());
        if let Poll::Ready(received) = outcome {
            received
        } else {
            unreachable!()
        }
    }

    /// Performs the whole server-side handshake on `io` and wraps it in a
    /// [`Stream`] on success.
    ///
    /// Steps, in order:
    /// 1. receive the upgrade request, using `buf` as storage;
    /// 2. require the request's host, then its path, to equal `host` and
    ///    `path` byte-for-byte;
    /// 3. derive the accept key from the request's `sec_key`;
    /// 4. send the upgrade response, reusing `buf`.
    ///
    /// # Errors
    ///
    /// Returns the error from receiving or sending, or a converted
    /// `HandshakeError::Manual` carrying `"host mismatch"` / `"path mismatch"`
    /// when the corresponding check fails.
    pub fn accept(mut io: IO, buf: &mut [u8], host: &str, path: &str) -> Result<Stream<IO, Role>> {
        let mut extra_headers = HttpHeader::new_storage();
        let mut request = Request::new_storage(&mut extra_headers);

        // SAFETY (assumed contract — confirm against `detail::recv_request`):
        // `request` is read for the last time when the accept key is derived
        // below, which happens before `buf` is handed to `send_response`.
        let _ = unsafe { Self::recv_request(&mut io, buf, &mut request) }?;

        // The host is validated before the path, so a request that fails both
        // checks reports the host mismatch.
        let mismatch = if request.host != host.as_bytes() {
            Some("host mismatch")
        } else if request.path != path.as_bytes() {
            Some("path mismatch")
        } else {
            None
        };
        if let Some(reason) = mismatch {
            return Err(HandshakeError::Manual(reason).into());
        }

        let accept_key = derive_accept_key(request.sec_key);
        let response = Response::new(&accept_key);
        Self::send_response(&mut io, buf, &response)?;

        Ok(Stream::new(io, Role::new()))
    }
}
/// Unit tests for the blocking server-side handshake helpers.
///
/// `LimitReadWriter`, `REQUEST` and `RESPONSE` come from the shared test
/// module of the parent. NOTE(review): `rlimit` / `wlimit` presumably cap the
/// number of bytes handled per `read` / `write` call — confirm against the
/// helper's definition.
#[cfg(test)]
mod test {
    use super::*;
    use super::super::test::*;
    use crate::role::Server;

    /// The response must be fully sent, and left encoded in `buf`, for every
    /// write limit from 1 to 256.
    #[test]
    fn send_upgrade_response() {
        fn run_limit(limit: usize) {
            let mut rw = LimitReadWriter {
                rbuf: Vec::new(),
                wbuf: Vec::new(),
                rlimit: 0,
                wlimit: limit,
                cursor: 0,
            };
            let response = Response::new(b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
            let mut buf = vec![0u8; 1024];
            let send_n =
                Endpoint::<_, Server>::send_response(&mut rw, &mut buf, &response).unwrap();
            assert_eq!(send_n, RESPONSE.len());
            assert_eq!(&buf[..send_n], RESPONSE);
        }

        for limit in 1..=256 {
            run_limit(limit);
        }
    }

    /// The request must be fully received and parsed for every read limit
    /// from 1 to 256.
    #[test]
    fn recv_upgrade_request() {
        fn run_limit(limit: usize) {
            let mut rw = LimitReadWriter {
                rbuf: Vec::from(REQUEST),
                wbuf: Vec::new(),
                rlimit: limit,
                wlimit: 0,
                cursor: 0,
            };
            let mut buf = vec![0u8; 1024];
            let mut headers = HttpHeader::new_storage();
            let mut request = Request::new_storage(&mut headers);
            let recv_n =
                unsafe { Endpoint::<_, Server>::recv_request(&mut rw, &mut buf, &mut request) }
                    .unwrap();
            assert_eq!(recv_n, REQUEST.len());
            assert_eq!(request.host, b"www.example.com");
            assert_eq!(request.path, b"/ws");
            assert_eq!(request.sec_key, b"dGhlIHNhbXBsZSBub25jZQ==");
            // Finish with `request` before inspecting `buf` directly.
            drop(request);
            assert_eq!(&buf[..recv_n], REQUEST);
        }

        for limit in 1..=256 {
            run_limit(limit);
        }
    }

    /// A full handshake against the sample request must succeed even when the
    /// transport is limited to one byte per read and per write.
    #[test]
    fn server_accept() {
        let mut rw = LimitReadWriter {
            rbuf: Vec::from(REQUEST),
            wbuf: Vec::new(),
            rlimit: 1,
            wlimit: 1,
            cursor: 0,
        };
        let mut buf = vec![0u8; 1024];
        let accepted = Endpoint::<_, Server>::accept(&mut rw, &mut buf, "www.example.com", "/ws");
        // Previously the result was discarded, so a failed handshake could
        // not fail this test.
        assert!(
            accepted.is_ok(),
            "server handshake failed: {:?}",
            accepted.err()
        );
    }
}