use hyper::body::Incoming;
use hyper::upgrade::Upgraded;
use hyper::Request;
use hyper::Response;
use hyper::StatusCode;
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use hyper_util::rt::TokioIo;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use std::future::Future;
use std::pin::Pin;
use crate::Role;
use crate::WebSocket;
use crate::WebSocketError;
/// Performs the client side of a WebSocket opening handshake over `socket`.
///
/// An HTTP/1.1 connection is established on `socket`, and its driver future is
/// handed to `executor` so it keeps running in the background. `request` is then
/// sent as-is. This function does not add any handshake headers, so the caller
/// is expected to have set them (e.g. `Upgrade`, `Connection` and a
/// `Sec-WebSocket-Key` produced by [`generate_key`]).
///
/// The response is checked by `verify` before the connection is upgraded. On
/// success the upgraded I/O is wrapped in a [`WebSocket`] with [`Role::Client`].
/// That socket is returned together with the server's handshake response.
///
/// # Errors
///
/// Returns an error if the HTTP handshake, sending the request, or the protocol
/// upgrade fails. It also returns an error if `verify` rejects the response.
pub async fn client<S, E, B>(
    executor: &E,
    request: Request<B>,
    socket: S,
) -> Result<(WebSocket<TokioIo<Upgraded>>, Response<Incoming>), WebSocketError>
where
    S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
    E: hyper::rt::Executor<Pin<Box<dyn Future<Output = ()> + Send>>>,
    B: hyper::body::Body + 'static + Send,
    B::Data: Send,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let io = TokioIo::new(socket);
    let (mut sender, connection) = hyper::client::conn::http1::handshake(io).await?;

    // The connection future must be polled for the request to make progress.
    // `with_upgrades` keeps the underlying I/O available for the protocol switch.
    let driver: Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(async move {
        if let Err(err) = connection.with_upgrades().await {
            eprintln!("Error polling connection: {}", err);
        }
    });
    executor.execute(driver);

    let mut response = sender.send_request(request).await?;
    verify(&response)?;

    let upgraded = hyper::upgrade::on(&mut response).await?;
    let websocket = WebSocket::after_handshake(TokioIo::new(upgraded), Role::Client);
    Ok((websocket, response))
}
/// Generates a fresh value for the `Sec-WebSocket-Key` request header.
///
/// The value is a random 16-byte nonce encoded with the standard (padded)
/// base64 alphabet, so the returned string is always 24 characters long.
pub fn generate_key() -> String {
    STANDARD.encode(rand::random::<[u8; 16]>())
}
/// Validates the server's response to a WebSocket opening handshake.
///
/// Checks the requirements of RFC 6455 §4.1 that can be decided from the
/// response alone:
///
/// * the status code must be `101 Switching Protocols`;
/// * the `Upgrade` header must be an ASCII case-insensitive match for `websocket`;
/// * the `Connection` header must *contain* an `Upgrade` token (ASCII
///   case-insensitive). `Connection` is a comma-separated token list, so values
///   such as `keep-alive, Upgrade` are accepted, and every `Connection` header
///   line is examined.
///
/// `Sec-WebSocket-Accept` is not checked here, because the request's
/// `Sec-WebSocket-Key` is not available to this function.
///
/// # Errors
///
/// * [`WebSocketError::InvalidStatusCode`] with the received status code when it
///   is not 101.
/// * [`WebSocketError::InvalidUpgradeHeader`] when the `Upgrade` header is
///   missing, is not visible ASCII, or is not `websocket`.
/// * [`WebSocketError::InvalidConnectionHeader`] when no `Connection` header
///   line carries an `Upgrade` token.
fn verify(response: &Response<Incoming>) -> Result<(), WebSocketError> {
    if response.status() != StatusCode::SWITCHING_PROTOCOLS {
        return Err(WebSocketError::InvalidStatusCode(
            response.status().as_u16(),
        ));
    }

    let headers = response.headers();

    let upgrade_ok = headers
        .get("Upgrade")
        .and_then(|h| h.to_str().ok())
        .map(|h| h.eq_ignore_ascii_case("websocket"))
        .unwrap_or(false);
    if !upgrade_ok {
        return Err(WebSocketError::InvalidUpgradeHeader);
    }

    // `Connection` is a token list and may be split across several header
    // lines, so look for an `Upgrade` token anywhere rather than requiring the
    // whole value to equal "Upgrade".
    let has_upgrade_token = headers
        .get_all("Connection")
        .iter()
        .filter_map(|h| h.to_str().ok())
        .flat_map(|h| h.split(','))
        .any(|token| token.trim().eq_ignore_ascii_case("Upgrade"));
    if !has_upgrade_token {
        return Err(WebSocketError::InvalidConnectionHeader);
    }

    Ok(())
}