use crate::server_runners::axum_tungstenite::rejection::*;
use crate::server_runners::axum_tungstenite::WebSocketUpgrade;
use crate::socket::SocketConfig;
use crate::Server;
use crate::ServerExt;
use crate::Socket;
use axum::extract::ConnectInfo;
use axum::extract::FromRequestParts;
use axum::response::Response;
use enfync::TryAdopt;
use http::request::Parts;
use std::net::SocketAddr;
/// Axum extractor for a WebSocket upgrade request.
///
/// Holds the pending upgrade together with the peer socket address and a copy
/// of the originating HTTP request; both are handed to the server when the
/// upgraded connection is accepted.
#[derive(Debug)]
pub struct Upgrade {
    /// Pending WebSocket upgrade, completed by `on_upgrade` / `on_upgrade_with_config`.
    ws: WebSocketUpgrade,
    /// Address taken from axum's `ConnectInfo<SocketAddr>` request extension.
    address: SocketAddr,
    /// Request head (method, URI, version, headers) with an empty `()` body.
    request: crate::Request,
}
/// Read-only access to the metadata captured while extracting the upgrade.
impl Upgrade {
    /// Socket address obtained from axum's `ConnectInfo<SocketAddr>` extension
    /// (the peer address when the router is served with
    /// `into_make_service_with_connect_info::<SocketAddr, _>()`).
    pub fn address(&self) -> &SocketAddr {
        let Self { address, .. } = self;
        address
    }

    /// Copy of the upgrade request's method, URI, version and headers,
    /// carrying an empty `()` body.
    pub fn request(&self) -> &crate::Request {
        let Self { request, .. } = self;
        request
    }
}
/// Extracts an [`Upgrade`] from the incoming request parts.
///
/// Besides the WebSocket handshake validation delegated to
/// [`WebSocketUpgrade`], this captures the peer [`SocketAddr`] and a copy of
/// the request head so they can later be passed to the server on accept.
impl<S> FromRequestParts<S> for Upgrade
where
    S: Send + Sync,
{
    type Rejection = WebSocketUpgradeRejection;

    /// # Panics
    ///
    /// Panics if the `ConnectInfo<SocketAddr>` extension is missing, i.e. the
    /// router was not served with
    /// `axum::Router::into_make_service_with_connect_info::<SocketAddr, _>()`.
    ///
    /// # Errors
    ///
    /// Returns the rejection produced by `WebSocketUpgrade::from_request_parts`
    /// when the request is not an acceptable WebSocket upgrade.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let ConnectInfo(address) = parts
            .extensions
            .get::<ConnectInfo<SocketAddr>>()
            .expect("Axum Server must be created with `axum::Router::into_make_service_with_connect_info::<SocketAddr, _>()`")
            .to_owned();

        // Snapshot the request head with an empty body. Every component is
        // already a validated `http` type, so mutating a fresh request is
        // infallible. This avoids the `Builder` path, whose (unreachable)
        // error branch would have been misreported as `InvalidConnectionHeader`.
        // Cloning the whole `HeaderMap` keeps multi-valued headers intact.
        let mut request = crate::Request::new(());
        *request.method_mut() = parts.method.clone();
        *request.uri_mut() = parts.uri.clone();
        *request.version_mut() = parts.version;
        *request.headers_mut() = parts.headers.clone();

        Ok(Self {
            ws: WebSocketUpgrade::from_request_parts(parts, state).await?,
            address,
            request,
        })
    }
}
/// Completion of a WebSocket upgrade into a crate [`Socket`].
impl Upgrade {
    /// Completes the upgrade with [`SocketConfig::default()`] and hands the
    /// resulting socket to `server`.
    pub fn on_upgrade<E: ServerExt + 'static>(self, server: Server<E>) -> Response {
        let default_config = SocketConfig::default();
        Self::on_upgrade_with_config(self, server, default_config)
    }

    /// Completes the upgrade using the supplied `socket_config`.
    ///
    /// Returns the response produced by the inner `WebSocketUpgrade::on_upgrade`.
    /// Once upgraded, the raw stream is wrapped in a [`Socket`] and passed to
    /// `server.accept` along with the captured request and address.
    ///
    /// # Panics
    ///
    /// The upgrade callback panics if `TokioHandle::try_adopt` fails. Per its
    /// message, this happens when it is not running inside a Tokio runtime.
    pub fn on_upgrade_with_config<E: ServerExt + 'static>(
        self,
        server: Server<E>,
        socket_config: SocketConfig,
    ) -> Response {
        let Self { ws, address, request } = self;
        ws.on_upgrade(move |stream| async move {
            let runtime = enfync::builtin::native::TokioHandle::try_adopt()
                .expect("axum server runner only works in a tokio runtime");
            server.accept(Socket::new(stream, socket_config, runtime), request, address);
        })
    }
}