use std::io::{Read, Write, Result};
use std::task::Poll;
use super::detail;
use super::Endpoint;
use crate::role::ClientRole;
use crate::handshake::{HttpHeader, Request, Response};
use crate::handshake::{new_sec_key, derive_accept_key};
use crate::error::HandshakeError;
use crate::stream::Stream;
/// Blocking (synchronous) client-side handshake helpers.
///
/// Each helper forwards to the matching function in [`detail`], supplying a
/// closure built on the blocking `Read`/`Write` methods of `IO`.
impl<IO: Read + Write, Role: ClientRole> Endpoint<IO, Role> {
    /// Sends a websocket upgrade request through `io`.
    ///
    /// Delegates to [`detail::send_request`], handing it `buf` and a closure
    /// that performs a blocking `io.write`. The unit tests in this file show
    /// that, on success, the returned count equals the length of the encoded
    /// request and that `buf[..count]` holds those bytes.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`detail::send_request`] reports.
    ///
    /// # Panics
    ///
    /// Panics if the helper yields `Poll::Pending`. The blocking write result
    /// is converted with `.into()` (presumably into `Poll::Ready`), so that
    /// outcome is treated as impossible here.
    pub fn send_request<const N: usize>(
        io: &mut IO,
        buf: &mut [u8],
        request: &Request<'_, '_, N>,
    ) -> Result<usize> {
        let outcome =
            detail::send_request(io, buf, request, |writer, chunk| writer.write(chunk).into());
        if let Poll::Ready(sent) = outcome {
            sent
        } else {
            unreachable!()
        }
    }

    /// Receives a websocket upgrade response from `io` into `response`.
    ///
    /// Delegates to [`detail::recv_response`], handing it `buf` and a closure
    /// that performs a blocking `io.read`. The unit tests in this file show
    /// that, on success, the returned count equals the length of the response
    /// and that `buf[..count]` holds the raw bytes.
    ///
    /// # Safety
    ///
    /// NOTE(review): the precise contract lives in `detail::recv_response` and
    /// is not visible here. The lifetime `'b` of `response` is not tied to the
    /// borrow of `buf`, so `response` presumably keeps references into `buf`
    /// that the borrow checker cannot track. Callers should keep `buf` alive
    /// and unmodified for as long as `response` is in use (the tests drop the
    /// response before touching `buf` again) - confirm against the helper.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`detail::recv_response`] reports.
    ///
    /// # Panics
    ///
    /// Panics if the helper yields `Poll::Pending`, which is treated as
    /// impossible for blocking IO.
    pub unsafe fn recv_response<'h, 'b: 'h, const N: usize>(
        io: &mut IO,
        buf: &mut [u8],
        response: &mut Response<'h, 'b, N>,
    ) -> Result<usize> {
        let outcome =
            detail::recv_response(io, buf, response, |reader, chunk| reader.read(chunk).into());
        if let Poll::Ready(received) = outcome {
            received
        } else {
            unreachable!()
        }
    }

    /// Performs the client side of the opening handshake and wraps `io` in a
    /// [`Stream`].
    ///
    /// A fresh key is created with `new_sec_key`, and the accept value derived
    /// from it is kept for later comparison. The upgrade request for
    /// `path`/`host` is sent, the response is read back using the same `buf`,
    /// and the response's `sec_accept` field is checked against the derived
    /// value.
    ///
    /// # Errors
    ///
    /// Propagates any error from sending the request or receiving the
    /// response. Returns `HandshakeError::SecWebSocketAccept` (converted into
    /// the IO error type) when the accept value in the response does not match
    /// the one derived from the key that was sent.
    pub fn connect(
        mut io: IO,
        buf: &mut [u8],
        host: &str,
        path: &str,
    ) -> Result<Stream<IO, Role>> {
        // The accept value is fixed as soon as the key is chosen, so compute
        // it before any IO takes place.
        let sec_key = new_sec_key();
        let expected_accept = derive_accept_key(&sec_key);

        let request = Request::new(path.as_bytes(), host.as_bytes(), &sec_key);
        let _ = Self::send_request(&mut io, buf, &request)?;

        let mut extra_headers = HttpHeader::new_storage();
        let mut response = Response::new_storage(&mut extra_headers);
        // SAFETY: NOTE(review) - presumably sound because `buf` is not written
        // to again while `response` is alive: `response` is only read below
        // and goes out of scope when this function returns. Confirm against
        // the contract of `detail::recv_response`.
        let _ = unsafe { Self::recv_response(&mut io, buf, &mut response) }?;

        if response.sec_accept == expected_accept {
            Ok(Stream::new(io, Role::new()))
        } else {
            Err(HandshakeError::SecWebSocketAccept.into())
        }
    }
}
#[cfg(test)]
mod test {
    use std::error::Error;
    use super::*;
    use super::super::test::*;
    use crate::error::HandshakeError;
    use crate::role::Client;

    /// Sends the sample upgrade request for every write limit in `1..=256`
    /// and checks the reported length and the encoded bytes left in the
    /// scratch buffer.
    #[test]
    fn send_upgrade_request() {
        fn run_limit(limit: usize) {
            let mut peer = LimitReadWriter {
                rbuf: Vec::new(),
                wbuf: Vec::new(),
                rlimit: 0,
                wlimit: limit,
                cursor: 0,
            };
            let request = Request::new(b"/ws", b"www.example.com", b"dGhlIHNhbXBsZSBub25jZQ==");
            let mut scratch = vec![0u8; 1024];

            let sent = Endpoint::<_, Client>::send_request(&mut peer, &mut scratch, &request)
                .unwrap();

            assert_eq!(sent, REQUEST.len());
            assert_eq!(&scratch[..sent], REQUEST);
        }

        (1..=256).for_each(run_limit);
    }

    /// Receives the sample upgrade response for every read limit in
    /// `1..=256` and checks the reported length, the parsed accept value and
    /// the raw bytes left in the scratch buffer.
    #[test]
    fn recv_upgrade_response() {
        fn run_limit(limit: usize) {
            let mut peer = LimitReadWriter {
                rbuf: Vec::from(RESPONSE),
                wbuf: Vec::new(),
                rlimit: limit,
                wlimit: 0,
                cursor: 0,
            };
            let mut scratch = vec![0u8; 1024];
            let mut headers = HttpHeader::new_storage();
            let mut response = Response::new_storage(&mut headers);

            // SAFETY: NOTE(review) - `scratch` outlives `response` and is not
            // touched again until `response` has been dropped below; this is
            // presumably what the unsafe contract asks for - confirm.
            let received = unsafe {
                Endpoint::<_, Client>::recv_response(&mut peer, &mut scratch, &mut response)
            }
            .unwrap();

            assert_eq!(received, RESPONSE.len());
            assert_eq!(response.sec_accept, b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
            // Let go of the parsed response before looking at the buffer again.
            drop(response);
            assert_eq!(&scratch[..received], RESPONSE);
        }

        (1..=256).for_each(run_limit);
    }

    /// Runs the full client handshake against the canned response, with the
    /// read and write limits both set to 1.
    ///
    /// If `connect` fails, the error's source must be
    /// `HandshakeError::SecWebSocketAccept`.
    #[test]
    fn client_connect() {
        let mut peer = LimitReadWriter {
            rbuf: Vec::from(RESPONSE),
            wbuf: Vec::new(),
            rlimit: 1,
            wlimit: 1,
            cursor: 0,
        };
        let mut scratch = vec![0u8; 1024];

        let outcome = Endpoint::<_, Client>::connect(&mut peer, &mut scratch, "example.com", "/");

        // NOTE(review): an `Ok` outcome passes without any further checks;
        // only the error path is inspected.
        if let Err(err) = outcome {
            let handshake_err = err
                .source()
                .and_then(|cause| cause.downcast_ref::<HandshakeError>())
                .unwrap();
            assert_eq!(*handshake_err, HandshakeError::SecWebSocketAccept);
        }
    }
}