use std::{
pin::Pin,
task::{Context, Poll},
};
use hyper_util::rt::TokioIo;
use pin_project::pin_project;
use sha1::{Digest, Sha1};
use bytes::Bytes;
use super::{HttpStream, Negotiation, Role, WebSocket};
#[cfg(feature = "axum")]
use {
super::{Options, MAX_PAYLOAD_READ, MAX_READ_BUFFER},
crate::{compression::WebSocketExtensions, Result},
http_body_util::Empty,
hyper::{header, Response},
std::future::Future,
};
/// A WebSocket upgrade request captured from an axum handler's request head.
///
/// Produced by the `FromRequestParts` extractor in this module, which checks the
/// `Sec-WebSocket-Key` / `Sec-WebSocket-Version` headers. Call `upgrade` to obtain the
/// `101 Switching Protocols` response together with the connection future.
#[cfg(feature = "axum")]
#[cfg_attr(docsrs, doc(cfg(feature = "axum")))]
pub struct IncomingUpgrade {
    /// `Sec-WebSocket-Accept` value already derived from the client's `Sec-WebSocket-Key`.
    key: String,
    /// hyper's upgrade handle, removed from the request's extensions.
    on_upgrade: hyper::upgrade::OnUpgrade,
    /// The client's `Sec-WebSocket-Extensions` offer, or `None` if absent or unparseable.
    extensions: Option<WebSocketExtensions>,
}
#[cfg_attr(docsrs, doc(cfg(feature = "axum")))]
#[cfg(feature = "axum")]
impl IncomingUpgrade {
    /// Accepts the upgrade.
    ///
    /// Returns the `101 Switching Protocols` response and an [`UpgradeFut`] that resolves
    /// to a server-side `WebSocket` once hyper hands over the connection. The response is
    /// presumably what must be sent back to the client for hyper's upgrade (awaited by
    /// `UpgradeFut`) to complete — see hyper's `upgrade` module docs.
    ///
    /// Compression (`Sec-WebSocket-Extensions`) is negotiated only when the client sent a
    /// parseable offer *and* `options.compression` is set. The merged offer is echoed in
    /// the response and used for the connection.
    ///
    /// If `options.max_read_buffer` is unset, it defaults to twice
    /// `options.max_payload_read` (saturating at the integer maximum). If that is also
    /// unset, it defaults to `MAX_READ_BUFFER`.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    ///
    /// # Panics
    ///
    /// Panics if the response builder rejects one of the headers (considered a bug).
    pub fn upgrade(self, options: Options) -> Result<(Response<Empty<Bytes>>, UpgradeFut)> {
        // `self.key` already holds the derived `Sec-WebSocket-Accept` value.
        let builder = Response::builder()
            .status(hyper::StatusCode::SWITCHING_PROTOCOLS)
            .header(header::CONNECTION, "upgrade")
            .header(header::UPGRADE, "websocket")
            .header(header::SEC_WEBSOCKET_ACCEPT, self.key);

        // Only negotiate extensions when both the client offered and the server allows them.
        let (builder, extensions) = match (self.extensions, options.compression.as_ref()) {
            (Some(client_offer), Some(server_offer)) => {
                let offer = server_offer.merge(&client_offer);
                let response = builder.header(header::SEC_WEBSOCKET_EXTENSIONS, offer.to_string());
                (response, Some(offer))
            }
            _ => (builder, None),
        };

        let response = builder
            .body(Empty::new())
            .expect("bug: failed to build response");

        // `saturating_mul` keeps a very large configured payload limit from overflowing.
        // Plain `* 2` would panic in debug builds or wrap to a tiny buffer in release.
        let max_read_buffer = options.max_read_buffer.unwrap_or_else(|| {
            options
                .max_payload_read
                .map_or(MAX_READ_BUFFER, |payload_read| payload_read.saturating_mul(2))
        });

        let stream = UpgradeFut {
            inner: self.on_upgrade,
            negotiation: Some(Negotiation {
                extensions,
                compression_level: options
                    .compression
                    .as_ref()
                    .map(|compression| compression.level),
                max_payload_read: options.max_payload_read.unwrap_or(MAX_PAYLOAD_READ),
                max_backpressure_write_boundary: options.max_backpressure_write_boundary,
                fragmentation: options.fragmentation.clone(),
                max_read_buffer,
                utf8: options.check_utf8,
            }),
        };

        Ok((response, stream))
    }
}
#[cfg(feature = "axum")]
#[cfg_attr(docsrs, doc(cfg(feature = "axum")))]
impl<S> axum_core::extract::FromRequestParts<S> for IncomingUpgrade
where
    S: Sync,
{
    type Rejection = hyper::StatusCode;

    /// Validates the WebSocket handshake headers and takes hyper's upgrade handle.
    ///
    /// Rejects with `400 Bad Request` in any of these cases:
    /// - `Sec-WebSocket-Key` is missing;
    /// - `Sec-WebSocket-Version` is absent or not exactly `13`;
    /// - the request carries no `hyper::upgrade::OnUpgrade` extension.
    ///
    /// A missing or unparseable `Sec-WebSocket-Extensions` header is not an error; the
    /// offer is simply recorded as `None`. Only the first header of each name is read.
    /// On success the `OnUpgrade` extension is removed from `parts`.
    fn from_request_parts(
        parts: &mut http::request::Parts,
        _state: &S,
    ) -> impl Future<Output = std::result::Result<Self, Self::Rejection>> + Send {
        async move {
            const REJECTED: hyper::StatusCode = hyper::StatusCode::BAD_REQUEST;

            let Some(client_key) = parts.headers.get(header::SEC_WEBSOCKET_KEY) else {
                return Err(REJECTED);
            };

            let version_is_13 = matches!(
                parts.headers.get(header::SEC_WEBSOCKET_VERSION),
                Some(version) if version.as_bytes() == b"13"
            );
            if !version_is_13 {
                return Err(REJECTED);
            }

            // A malformed extension offer is ignored rather than rejected.
            let extensions = match parts.headers.get(header::SEC_WEBSOCKET_EXTENSIONS) {
                Some(value) => value
                    .to_str()
                    .ok()
                    .and_then(|text| text.parse::<WebSocketExtensions>().ok()),
                None => None,
            };

            let Some(on_upgrade) = parts.extensions.remove::<hyper::upgrade::OnUpgrade>() else {
                return Err(REJECTED);
            };

            let key = sec_websocket_protocol(client_key.as_bytes());
            Ok(Self {
                key,
                on_upgrade,
                extensions,
            })
        }
    }
}
/// Computes the `Sec-WebSocket-Accept` header value for a client's `Sec-WebSocket-Key`.
///
/// Follows RFC 6455 §4.2.2:
/// 1. append the protocol's fixed GUID to the raw key bytes;
/// 2. hash the result with SHA-1;
/// 3. encode the digest as padded standard base64.
///
/// For the RFC's sample key `dGhlIHNhbXBsZSBub25jZQ==` the result is
/// `s3pPLMBiTxaQ9kYGzzhZRbK+xOo=`.
///
/// Despite its name, this function has nothing to do with `Sec-WebSocket-Protocol`
/// (subprotocol) negotiation.
pub(super) fn sec_websocket_protocol(key: &[u8]) -> String {
    use base64::prelude::*;

    // GUID fixed by RFC 6455 for deriving the accept value.
    const HANDSHAKE_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let mut hasher = Sha1::new();
    for chunk in [key, HANDSHAKE_GUID] {
        hasher.update(chunk);
    }
    let digest = hasher.finalize();

    BASE64_STANDARD.encode(&digest[..])
}
/// Future that resolves to a server-role `WebSocket` once hyper completes the HTTP upgrade.
///
/// Built by `IncomingUpgrade::upgrade` in this file. The fields are `pub(super)`, so the
/// parent module may construct it as well.
#[pin_project]
#[derive(Debug)]
pub struct UpgradeFut {
/// hyper's pending upgrade; polled to obtain the upgraded connection.
#[pin]
pub(super) inner: hyper::upgrade::OnUpgrade,
/// Settings applied to the resulting `WebSocket`. Taken (left as `None`) when the
/// future resolves successfully.
pub(super) negotiation: Option<Negotiation>,
}
impl std::future::Future for UpgradeFut {
    type Output = hyper::Result<WebSocket<HttpStream>>;

    /// Waits for hyper to hand over the upgraded connection.
    ///
    /// Once the connection arrives, it is wrapped in a `WebSocket` with the server role,
    /// configured with the stored negotiation settings. If hyper fails to upgrade the
    /// connection, the future resolves to hyper's error.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has already resolved successfully, because the
    /// negotiation settings are consumed on completion.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.project();

        // Propagates `Pending`, then the upgrade error (if any) via `?`.
        let upgraded = std::task::ready!(this.inner.poll(cx))?;

        // Only taken once the upgrade succeeded, so a failed upgrade leaves it in place.
        let negotiation = this
            .negotiation
            .take()
            .expect("UpgradeFut polled after completion");

        let io = TokioIo::new(upgraded);
        Poll::Ready(Ok(WebSocket::new(
            Role::Server,
            HttpStream::from(io),
            // Presumably bytes already read past the handshake. Nothing is pre-read on
            // this path. NOTE(review): confirm against `WebSocket::new`.
            Bytes::new(),
            negotiation,
        )))
    }
}