use std::io::{self, Read, Write};
use super::frame::Role;
use super::frame_reader::{FrameReader, FrameReaderBuilder};
use super::frame_writer::FrameWriter;
use super::handshake::{self, HandshakeError};
use super::stream::{Client, ClientBuilder, Error, parse_ws_url};
use crate::buf::WriteBuf;
#[cfg(feature = "tls")]
use crate::tls::TlsCodec;
/// A resumable, in-progress WebSocket client handshake.
///
/// Created by `ClientBuilder::begin_connect` and advanced by calling `poll`
/// until it yields a `Client`. The stream is wrapped in `ManuallyDrop`
/// because, on success, ownership of it is moved out into the `Client`; the
/// `Drop` impl only drops it while `finished` is `false`.
pub struct Connecting<S> {
    /// Underlying transport; moved out into the `Client` on completion.
    stream: std::mem::ManuallyDrop<S>,
    /// Which handshake step `poll` runs next.
    state: ConnectState,
    /// TLS session for URLs that require TLS; `None` for plaintext.
    #[cfg(feature = "tls")]
    tls: Option<TlsCodec>,
    /// Frame-reader configuration used to build the `Client`'s reader.
    reader_builder: FrameReaderBuilder,
    /// Capacity passed to the `Client`'s `WriteBuf`.
    write_buf_capacity: usize,
    /// Headroom passed to the `Client`'s `WriteBuf`.
    write_buf_headroom: usize,
    /// `Sec-WebSocket-Key` value (base64 text stored as bytes).
    ws_key: [u8; 24],
    /// Serialized HTTP upgrade request.
    req_buf: Vec<u8>,
    /// Number of bytes of `req_buf` already handed off for sending.
    req_offset: usize,
    /// Incremental parser for the server's HTTP response.
    resp_reader: crate::http::ResponseReader,
    /// Host component of the connect URL (sent as the `Host` header).
    host: String,
    /// Path component of the connect URL (the request target).
    path: String,
    /// When `true`, the `Drop` impl does not drop `stream`.
    finished: bool,
}
/// Step of the client handshake that `Connecting::poll` executes next.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ConnectState {
    /// Writing outgoing TLS handshake records.
    #[cfg(feature = "tls")]
    TlsWrite,
    /// Reading TLS handshake records from the server.
    #[cfg(feature = "tls")]
    TlsRead,
    /// Sending the HTTP upgrade request.
    HttpSend,
    /// Receiving and validating the HTTP upgrade response.
    HttpRecv,
    /// Response validated; the stream is handed over to a `Client`.
    Done,
}
impl ClientBuilder {
    /// Begins a resumable WebSocket client handshake over an already-connected
    /// `stream`.
    ///
    /// The URL is parsed with `parse_ws_url`. When it requires TLS, a
    /// `TlsCodec` is created for the URL's host from the builder's TLS config
    /// (or from `TlsConfig::new()` when none was configured), and the returned
    /// [`Connecting`] starts in the TLS-write state. Otherwise the HTTP upgrade
    /// request is serialized immediately and the first `poll` sends it.
    ///
    /// This function does not read from or write to `stream`; drive the
    /// handshake with `Connecting::poll`.
    ///
    /// # Errors
    ///
    /// Propagates URL-parsing and TLS-setup failures. In builds without the
    /// `tls` feature, a URL that requires TLS yields `Error::TlsNotEnabled`.
    pub fn begin_connect<S: Read + Write>(
        self,
        stream: S,
        url: &str,
    ) -> Result<Connecting<S>, Error> {
        let parsed = parse_ws_url(url)?;

        #[cfg(feature = "tls")]
        let tls = if !parsed.tls {
            None
        } else {
            let config = if let Some(configured) = self.tls_config {
                configured
            } else {
                crate::tls::TlsConfig::new().map_err(Error::Tls)?
            };
            Some(TlsCodec::new(&config, parsed.host)?)
        };

        #[cfg(not(feature = "tls"))]
        if parsed.tls {
            return Err(Error::TlsNotEnabled);
        }

        let ws_key = handshake::generate_key();

        // A TLS session must complete its own handshake before any HTTP is sent.
        #[cfg(feature = "tls")]
        let initial_state = match &tls {
            Some(_) => ConnectState::TlsWrite,
            None => ConnectState::HttpSend,
        };
        #[cfg(not(feature = "tls"))]
        let initial_state = ConnectState::HttpSend;

        let mut connecting = Connecting {
            stream: std::mem::ManuallyDrop::new(stream),
            state: initial_state,
            #[cfg(feature = "tls")]
            tls,
            reader_builder: self.reader_builder,
            write_buf_capacity: self.write_buf_capacity,
            write_buf_headroom: self.write_buf_headroom,
            ws_key,
            req_buf: Vec::new(),
            req_offset: 0,
            resp_reader: crate::http::ResponseReader::new(4096),
            host: parsed.host.to_owned(),
            path: parsed.path.to_owned(),
            finished: false,
        };

        // Plaintext connections can serialize the upgrade request right away.
        if initial_state == ConnectState::HttpSend {
            let request_path = connecting.path.clone();
            connecting.prepare_http_request(&request_path);
        }
        Ok(connecting)
    }
}
impl<S: Read + Write> Connecting<S> {
pub fn poll(&mut self) -> Result<Option<Client<S>>, Error> {
loop {
match self.state {
#[cfg(feature = "tls")]
ConnectState::TlsWrite => {
let tls = self
.tls
.as_mut()
.expect("TLS codec must exist in TLS handshake state");
match tls.write_tls_to(&mut *self.stream) {
Ok(_) => {}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
Err(e) => return Err(e.into()),
}
if tls.is_handshaking() {
self.state = ConnectState::TlsRead;
} else {
self.state = ConnectState::HttpSend;
let path = self.path.clone();
self.prepare_http_request(&path);
}
}
#[cfg(feature = "tls")]
ConnectState::TlsRead => {
let tls = self
.tls
.as_mut()
.expect("TLS codec must exist in TLS handshake state");
match tls.read_tls_from(&mut *self.stream) {
Ok(0) => return Err(Error::Handshake(HandshakeError::MalformedHttp)),
Ok(_) => {}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
Err(e) => return Err(e.into()),
}
tls.process_new_packets()?;
if tls.wants_write() {
self.state = ConnectState::TlsWrite;
} else if !tls.is_handshaking() {
self.state = ConnectState::HttpSend;
let path = self.path.clone();
self.prepare_http_request(&path);
}
}
ConnectState::HttpSend => {
if self.req_offset >= self.req_buf.len() {
self.state = ConnectState::HttpRecv;
return Ok(None);
}
#[cfg(feature = "tls")]
if let Some(tls) = &mut self.tls {
if self.req_offset < self.req_buf.len() {
let data = &self.req_buf[self.req_offset..];
tls.encrypt(data)?;
self.req_offset = self.req_buf.len(); }
match tls.write_tls_to(&mut *self.stream) {
Ok(_) => {}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
Err(e) => return Err(e.into()),
}
if tls.wants_write() {
return Ok(None);
}
self.state = ConnectState::HttpRecv;
return Ok(None);
}
{
let data = &self.req_buf[self.req_offset..];
let n = match (*self.stream).write(data) {
Ok(n) => n,
Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
Err(e) => return Err(e.into()),
};
if n == 0 {
return Err(Error::Io(io::Error::new(
io::ErrorKind::WriteZero,
"write returned 0 during handshake",
)));
}
self.req_offset += n;
if self.req_offset >= self.req_buf.len() {
self.state = ConnectState::HttpRecv;
}
}
return Ok(None);
}
ConnectState::HttpRecv => {
let mut tmp = [0u8; 4096];
let n = self.read_bytes(&mut tmp)?;
if n == 0 {
return Ok(None);
}
self.resp_reader
.read(&tmp[..n])
.map_err(|_| HandshakeError::MalformedHttp)?;
match self.resp_reader.next() {
Ok(Some(resp)) => {
if resp.status != 101 {
return Err(HandshakeError::UnexpectedStatus(resp.status).into());
}
let upgrade = resp
.header("Upgrade")
.ok_or(HandshakeError::MissingUpgrade)?;
if !upgrade.eq_ignore_ascii_case("websocket") {
return Err(HandshakeError::MissingUpgrade.into());
}
let conn = resp
.header("Connection")
.ok_or(HandshakeError::MissingConnection)?;
if !conn
.as_bytes()
.windows(7)
.any(|w| w.eq_ignore_ascii_case(b"upgrade"))
{
return Err(HandshakeError::MissingConnection.into());
}
let key_str = std::str::from_utf8(&self.ws_key)
.expect("base64 output is valid ASCII");
let accept = resp
.header("Sec-WebSocket-Accept")
.ok_or(HandshakeError::InvalidAcceptKey)?;
if !handshake::validate_accept(key_str, accept) {
return Err(HandshakeError::InvalidAcceptKey.into());
}
self.state = ConnectState::Done;
}
Ok(None) => return Ok(None),
Err(_) => return Err(HandshakeError::MalformedHttp.into()),
}
}
ConnectState::Done => {
return Ok(Some(self.finish()?));
}
}
}
}
pub fn wants_write(&self) -> bool {
matches!(
self.state,
ConnectState::HttpSend | if_tls!(ConnectState::TlsWrite)
)
}
pub fn wants_read(&self) -> bool {
matches!(
self.state,
ConnectState::HttpRecv | if_tls!(ConnectState::TlsRead)
)
}
pub fn stream(&self) -> &S {
&self.stream
}
pub fn stream_mut(&mut self) -> &mut S {
&mut self.stream
}
fn prepare_http_request(&mut self, path: &str) {
let key_str = std::str::from_utf8(&self.ws_key).expect("base64 output is valid ASCII");
let headers = [
("Host", self.host.as_str()),
("Upgrade", "websocket"),
("Connection", "Upgrade"),
("Sec-WebSocket-Key", key_str),
("Sec-WebSocket-Version", "13"),
];
let size = crate::http::request_size("GET", path, &headers);
let mut buf = vec![0u8; size];
let n = crate::http::write_request("GET", path, &headers, &mut buf)
.expect("request fits in handshake buffer");
self.req_buf = buf[..n].to_vec();
self.req_offset = 0;
}
fn finish(&mut self) -> Result<Client<S>, Error> {
self.finished = true;
let reader_builder = std::mem::replace(&mut self.reader_builder, FrameReader::builder());
let mut reader = reader_builder.role(Role::Client).build();
let remainder = self.resp_reader.remainder();
if !remainder.is_empty() {
reader
.read(remainder)
.map_err(|_| Error::Handshake(HandshakeError::MalformedHttp))?;
}
let stream = unsafe { std::mem::ManuallyDrop::take(&mut self.stream) };
Ok(Client::from_parts_internal(
stream,
reader,
FrameWriter::new(Role::Client),
WriteBuf::new(self.write_buf_capacity, self.write_buf_headroom),
))
}
fn read_bytes(&mut self, dst: &mut [u8]) -> Result<usize, Error> {
#[cfg(feature = "tls")]
if let Some(tls) = &mut self.tls {
return match tls.read_tls_from(&mut *self.stream) {
Ok(0) => Err(Error::Io(io::Error::new(
io::ErrorKind::UnexpectedEof,
"connection closed during TLS handshake",
))),
Ok(_) => {
tls.process_new_packets()?;
tls.read_plaintext(dst).map_err(Error::Tls)
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => Ok(0),
Err(e) => Err(e.into()),
};
}
match (*self.stream).read(dst) {
Ok(n) => Ok(n),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => Ok(0),
Err(e) => Err(e.into()),
}
}
}
impl<S> Drop for Connecting<S> {
    /// Drops the stream unless it has already been handed off (`finished`).
    fn drop(&mut self) {
        if self.finished {
            return;
        }
        // SAFETY: `finished` is still false, so the stream has not been taken
        // out of its `ManuallyDrop`; this is the only drop of that value.
        unsafe { std::mem::ManuallyDrop::drop(&mut self.stream) };
    }
}
// `if_tls!(PAT)`: a pattern helper gated on the `tls` feature.
//
// With the `tls` feature enabled it expands to `PAT` unchanged.
#[cfg(feature = "tls")]
macro_rules! if_tls {
    ($pat:pat) => {
        $pat
    };
}
// Without the `tls` feature the argument is discarded (it is only parsed as a
// pattern, never name-resolved, so it may name TLS-only variants). The macro
// then expands to `ConnectState::Done`.
// NOTE(review): as a result, a pattern such as `A | if_tls!(B)` also matches
// `ConnectState::Done` in non-TLS builds but not in TLS builds. Confirm this
// is acceptable wherever the macro is used.
#[cfg(not(feature = "tls"))]
macro_rules! if_tls {
    ($pat:pat) => {
        ConnectState::Done
    }; }
// Importing the macro gives it path-based scope, which makes it usable
// anywhere in this module, including above these definitions. Without the
// import, a `macro_rules!` name is only in scope textually after its definition.
use if_tls;