use hyper::upgrade::Upgraded;
use hyper::Body;
use hyper::Request;
use hyper::Response;
use hyper::StatusCode;
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use std::error::Error;
use std::future::Future;
use std::pin::Pin;
use crate::Role;
use crate::WebSocket;
/// Performs the client side of a WebSocket opening handshake over `socket`.
///
/// Runs a hyper client handshake on `socket` and hands the resulting connection
/// future to `executor`. It then sends `request`, checks the response with
/// `verify`, and waits for hyper to hand over the upgraded I/O. On success it
/// returns the [`WebSocket`] (in [`Role::Client`]) together with the server's
/// response.
///
/// `request` is sent unmodified, so the caller must already have set the
/// WebSocket upgrade headers on it.
///
/// # Errors
///
/// Returns an error if any of the following fails:
/// - the hyper handshake,
/// - sending the request,
/// - validation of the response (non-101 status, bad `Upgrade` or `Connection` header),
/// - the upgrade itself.
pub async fn client<S, E>(
    executor: &E,
    request: Request<Body>,
    socket: S,
) -> Result<(WebSocket<Upgraded>, Response<Body>), Box<dyn Error + Send + Sync>>
where
    S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
    E: hyper::rt::Executor<Pin<Box<dyn Future<Output = ()> + Send>>>,
{
    let (mut sender, conn) = hyper::client::conn::handshake(socket).await?;

    // The connection future must keep being polled in the background for
    // `sender` to make progress. Because it is detached, its errors can only
    // be reported here, not returned to the caller.
    let fut = Box::pin(async move {
        if let Err(e) = conn.await {
            eprintln!("Error polling connection: {}", e);
        }
    });
    executor.execute(fut);

    let mut response = sender.send_request(request).await?;
    verify(&response)?;

    // `?` performs the same `hyper::Error -> Box<dyn Error>` conversion the
    // former explicit `match` did.
    let upgraded = hyper::upgrade::on(&mut response).await?;
    Ok((WebSocket::after_handshake(upgraded, Role::Client), response))
}
/// Generates a random handshake key.
///
/// Draws 16 random bytes with `rand::random` and returns them encoded with the
/// standard base64 alphabet (padded), which always yields a 24-character string.
///
/// Presumably intended as the `Sec-WebSocket-Key` request header value; confirm
/// against callers.
pub fn generate_key() -> String {
    let nonce = rand::random::<[u8; 16]>();
    STANDARD.encode(nonce)
}
/// Checks that `response` is an acceptable WebSocket handshake reply.
///
/// The response is accepted only if all of the following hold:
/// - the status is `101 Switching Protocols`;
/// - the first `Upgrade` header is an ASCII case-insensitive match for `websocket`;
/// - some `Connection` header contains an `upgrade` token (ASCII case-insensitive).
///
/// `Connection` is a comma-separated token list that may also be repeated.
/// Values such as `keep-alive, Upgrade` are therefore accepted, as RFC 6455
/// §4.1 requires of clients.
///
/// NOTE(review): `Sec-WebSocket-Accept` is not validated here. Doing so would
/// need the `Sec-WebSocket-Key` that was sent with the request.
///
/// # Errors
///
/// Returns an error naming the first check that failed.
fn verify(
    response: &Response<Body>,
) -> Result<(), Box<dyn Error + Send + Sync>> {
    if response.status() != StatusCode::SWITCHING_PROTOCOLS {
        return Err("Invalid status code".into());
    }

    let headers = response.headers();

    let upgrade_ok = headers
        .get("Upgrade")
        .and_then(|h| h.to_str().ok())
        .map(|h| h.eq_ignore_ascii_case("websocket"))
        .unwrap_or(false);
    if !upgrade_ok {
        return Err("Invalid Upgrade header".into());
    }

    // Do not compare the whole first value. `Connection` may list several
    // tokens and may appear more than once, so look for an `upgrade` token
    // anywhere across all occurrences.
    let connection_ok = headers
        .get_all("Connection")
        .iter()
        .filter_map(|h| h.to_str().ok())
        .flat_map(|h| h.split(','))
        .any(|token| token.trim().eq_ignore_ascii_case("upgrade"));
    if !connection_ok {
        return Err("Invalid Connection header".into());
    }

    Ok(())
}