use std::{fmt, rc::Rc};
pub use crate::ws::{CloseCode, CloseReason, Frame, Message, WsSink};
use crate::http::{StatusCode, body::BodySize, h1, header};
use crate::io::{DispatchItem, IoConfig, Reason};
use crate::service::{
IntoServiceFactory, Service, ServiceCtx, ServiceFactory, chain_factory,
fn_factory_with_config,
};
use crate::web::{HttpRequest, HttpResponse};
use crate::ws::{self, error::HandshakeError, error::WsError, handshake};
use crate::{SharedCfg, rt, time::Seconds};
// Per-thread shared configuration named "WS". `start_with` installs it on every
// upgraded WebSocket I/O stream (via `io.set_config`).
// The I/O keepalive timeout is set to `Seconds::ZERO`. This presumably disables
// the I/O-level keepalive for websocket connections; TODO confirm the meaning of a
// zero timeout in `IoConfig::set_keepalive_timeout`.
thread_local! {
    static CFG: SharedCfg = SharedCfg::new("WS")
        .add(IoConfig::new().set_keepalive_timeout(Seconds::ZERO))
        .into();
}
/// Iterates over the subprotocols the client offered in `Sec-WebSocket-Protocol`.
///
/// Every occurrence of the header is consulted, in order. Each value is split on
/// `,` and each entry is trimmed; empty entries are skipped. Header values whose
/// `to_str()` conversion fails are ignored entirely.
///
/// The yielded slices borrow from `req`'s headers.
pub fn subprotocols(req: &HttpRequest) -> impl Iterator<Item = &str> {
    req.headers()
        .get_all(header::SEC_WEBSOCKET_PROTOCOL)
        // Drop header values that are not representable as `&str`.
        .filter_map(|value| value.to_str().ok())
        .flat_map(|list| {
            list.split(',')
                .map(str::trim)
                .filter(|proto| !proto.is_empty())
        })
}
/// Performs the WebSocket handshake for `req` and runs a frame-level service on
/// the upgraded connection.
///
/// `factory` builds a service that receives decoded [`Frame`]s and may answer
/// each one with an optional [`Message`]. Internally the service is wrapped in
/// `DispatchService`, which does three things:
/// - ignores dispatcher control items;
/// - maps dispatcher stop reasons to [`WsError`] variants;
/// - closes the connection after a `Frame::Close` has been handled.
///
/// Errors produced by the user service are wrapped in `WsError::Service`.
///
/// When `subprotocol` is `Some`, its value is sent back in the
/// `Sec-WebSocket-Protocol` response header. It is not checked against the
/// protocols the client offered; use [`subprotocols`] to pick one.
///
/// # Errors
///
/// Returns `Err` if the handshake fails ([`HandshakeError`]) or if the service
/// factory fails to initialize (`T::InitError`).
pub async fn start<T, F, P, Err>(
    req: HttpRequest,
    subprotocol: Option<P>,
    factory: F,
) -> Result<HttpResponse, Err>
where
    T: ServiceFactory<Frame, WsSink, Response = Option<Message>> + 'static,
    T::Error: fmt::Debug,
    F: IntoServiceFactory<T, Frame, WsSink>,
    P: AsRef<str>,
    Err: From<T::InitError> + From<HandshakeError>,
{
    // Map the user service's errors into `WsError::Service` so the service fits
    // `DispatchService`'s `WsError<E>` error type.
    let inner_factory = Rc::new(chain_factory(factory).map_err(WsError::Service));
    let factory = fn_factory_with_config(async move |sink: WsSink| {
        let srv = inner_factory.create(sink.clone()).await?;
        // `sink` is owned by this call and not used again, so move it instead of
        // cloning it a second time.
        Ok::<_, T::InitError>(DispatchService { srv, sink })
    });
    start_with(req, subprotocol, factory).await
}
/// Performs the WebSocket handshake for `req` and spawns a dispatcher that
/// drives `factory`'s service over the upgraded connection.
///
/// Unlike [`start`], the service sees raw [`DispatchItem<ws::Codec>`] values,
/// including control and stop items, rather than only decoded frames.
///
/// The steps are:
/// 1. Validate the request with `handshake` and build its response. When a
///    `subprotocol` is given, add `Sec-WebSocket-Protocol: <subprotocol>` (it is
///    not validated against the client's offer).
/// 2. Take the I/O stream out of the request and encode the response head
///    directly onto it.
/// 3. Create the service with a [`WsSink`] over the stream.
/// 4. Install the thread-local `CFG` and stop the I/O timer.
/// 5. Spawn `crate::io::Dispatcher` on the runtime.
///
/// The returned `HttpResponse` is a bare `200 OK`. The real upgrade response has
/// already been written to the I/O stream in step 2, so this value is presumably
/// only a placeholder for the HTTP layer; TODO confirm.
///
/// # Errors
///
/// - `HandshakeError` from `handshake` when the request is not a valid upgrade.
/// - `HandshakeError::NoWebsocketUpgrade` when the I/O stream cannot be taken
///   from the request, or the response head cannot be encoded.
/// - `T::InitError` when the service factory fails. By that point the upgrade
///   response has already been written to the stream.
pub async fn start_with<T, F, P, Err>(
    req: HttpRequest,
    subprotocol: Option<P>,
    factory: F,
) -> Result<HttpResponse, Err>
where
    T: ServiceFactory<DispatchItem<ws::Codec>, WsSink, Response = Option<Message>>
        + 'static,
    T::Error: fmt::Debug,
    F: IntoServiceFactory<T, DispatchItem<ws::Codec>, WsSink>,
    P: AsRef<str>,
    Err: From<T::InitError> + From<HandshakeError>,
{
    log::trace!("Start ws handshake verification for {:?}", req.path());

    // Step 1: validate the upgrade request and prepare the response head.
    let mut builder = handshake(req.head())?;
    if let Some(proto) = subprotocol {
        builder.set_header(header::SEC_WEBSOCKET_PROTOCOL, proto.as_ref());
    }
    let head = builder.finish().into_parts().0;

    // Step 2: detach the connection from the HTTP layer and write the response
    // head with the HTTP/1 codec.
    let upgraded = req
        .head()
        .take_io()
        .ok_or(HandshakeError::NoWebsocketUpgrade)?;
    let io = upgraded.0;
    let h1_codec = upgraded.1;
    io.encode(h1::Message::Item((head, BodySize::Empty)), &h1_codec)
        .map_err(|_encode_err| HandshakeError::NoWebsocketUpgrade)?;
    log::trace!("Ws handshake verification completed for {:?}", req.path());

    // Step 3: from here on the stream speaks the websocket protocol.
    let ws_codec = ws::Codec::new();
    let sink = WsSink::new(io.get_ref(), ws_codec.clone());
    let service = factory.into_factory().create(sink.clone()).await?;

    // Step 4: switch the stream to the websocket I/O configuration.
    io.set_config(CFG.with(|shared| shared.clone()));
    io.stop_timer();

    // Step 5: the dispatcher owns the connection for the rest of its lifetime.
    rt::spawn(async move {
        let res = crate::io::Dispatcher::new(io, ws_codec, service).await;
        log::trace!("Ws handler is terminated: {res:?}");
    });

    Ok(HttpResponse::new(StatusCode::OK))
}
/// Adapter used by [`start`]. It turns the dispatcher's
/// [`DispatchItem<ws::Codec>`] stream into calls on a frame-level service.
struct DispatchService<S> {
    /// The wrapped service, which handles decoded [`Frame`]s.
    srv: S,
    /// Sink for the same connection. It is cloned to close the I/O stream after
    /// a `Frame::Close` has been handled.
    sink: WsSink,
}
/// Routes dispatcher items to the inner frame service.
///
/// - `Item(frame)` is forwarded to `srv`. For a `Frame::Close`, a task that
///   closes the connection's I/O is spawned after the inner call completes,
///   whether or not that call succeeded.
/// - `Control(_)` produces no response.
/// - `Stop(reason)` becomes the matching [`WsError`] variant.
impl<S, E> Service<DispatchItem<ws::Codec>> for DispatchService<S>
where
    S: Service<Frame, Response = Option<Message>, Error = WsError<E>>,
    E: fmt::Debug,
{
    type Response = Option<Message>;
    type Error = WsError<E>;

    crate::forward_ready!(srv);
    crate::forward_poll!(srv);
    crate::forward_shutdown!(srv);

    async fn call(
        &self,
        req: DispatchItem<ws::Codec>,
        ctx: ServiceCtx<'_, Self>,
    ) -> Result<Self::Response, Self::Error> {
        // Handle every non-frame item up front; only a frame continues below.
        let frame = match req {
            DispatchItem::Item(frame) => frame,
            DispatchItem::Control(_) => return Ok(None),
            DispatchItem::Stop(reason) => {
                return Err(match reason {
                    Reason::KeepAliveTimeout => WsError::KeepAlive,
                    Reason::ReadTimeout => WsError::ReadTimeout,
                    Reason::Decoder(e) | Reason::Encoder(e) => WsError::Protocol(e),
                    Reason::Io(e) => WsError::Disconnected(e),
                });
            }
        };

        // Decide before the frame is moved into the inner service whether the
        // connection must be closed afterwards.
        let closer = matches!(frame, Frame::Close(_)).then(|| self.sink.clone());

        let outcome = ctx.call(&self.srv, frame).await;

        // Close in a separate task rather than inline here.
        if let Some(sink) = closer {
            rt::spawn(async move { sink.io().close() });
        }

        outcome
    }
}