use alloc::{string::ToString, sync::Arc, vec::Vec};
use core::time::Duration;
use async_lock::Mutex;
use http_body_util::{BodyExt, Full, Limited};
use hyper::{
Method, Request, Response, StatusCode,
body::{Bytes, Incoming},
};
use subduction_core::{
authenticated::Authenticated,
handshake::{self, audience::Audience},
nonce_cache::NonceCache,
peer::id::PeerId,
timeout::Timeout,
timestamp::TimestampSeconds,
};
use subduction_crypto::signer::Signer;
use future_form::{FutureForm, Sendable};
use futures::{FutureExt, future::BoxFuture};
use crate::{
DEFAULT_MAX_BODY_SIZE, DEFAULT_POLL_TIMEOUT_SECS, SESSION_ID_HEADER,
error::ServerError,
session::{SessionEntry, SessionId, SessionStore},
transport::HttpLongPollTransport,
};
/// Request handler for the HTTP long-poll transport endpoints (`/lp/*`).
///
/// Holds the [`SessionStore`] of established sessions, the inputs forwarded to
/// `handshake::respond` for every `/lp/handshake` request, and the per-request
/// limits (body size, `/lp/recv` poll timeout).
///
/// NOTE(review): whether clones of this handler share the same sessions depends
/// on `SessionStore`'s `Clone` impl, which is not visible here — confirm.
#[derive(Debug, Clone)]
pub struct LongPollHandler<Sig, O: Timeout<Sendable> + Send + Sync> {
// Established sessions, looked up by `SessionId`.
sessions: SessionStore,
// Signer forwarded to `handshake::respond`.
signer: Sig,
// Nonce cache forwarded to `handshake::respond`.
nonce_cache: Arc<NonceCache>,
// Our own peer id, forwarded to `handshake::respond`.
our_peer_id: PeerId,
// Optional discovery audience, forwarded to `handshake::respond`.
discovery_audience: Option<Audience>,
// Maximum clock drift, forwarded to `handshake::respond`.
handshake_max_drift: Duration,
// Timer used to bound each `/lp/recv` long poll.
timeout: O,
// Upper bound, in bytes, on request bodies read by the handler.
max_body_size: usize,
// How long `/lp/recv` waits for outbound data before replying `204 No Content`.
poll_timeout: Duration,
}
impl<Sig: Signer<Sendable> + Clone + Send + Sync, O: Timeout<Sendable> + Clone + Send + Sync>
LongPollHandler<Sig, O>
{
#[must_use]
pub fn new(
signer: Sig,
nonce_cache: Arc<NonceCache>,
our_peer_id: PeerId,
discovery_audience: Option<Audience>,
handshake_max_drift: Duration,
timeout: O,
) -> Self {
Self {
sessions: SessionStore::new(),
signer,
nonce_cache,
our_peer_id,
discovery_audience,
handshake_max_drift,
timeout,
max_body_size: DEFAULT_MAX_BODY_SIZE,
poll_timeout: Duration::from_secs(DEFAULT_POLL_TIMEOUT_SECS),
}
}
#[must_use]
pub const fn with_max_body_size(mut self, size: usize) -> Self {
self.max_body_size = size;
self
}
#[must_use]
pub const fn with_poll_timeout(mut self, timeout: Duration) -> Self {
self.poll_timeout = timeout;
self
}
#[must_use]
pub const fn sessions(&self) -> &SessionStore {
&self.sessions
}
pub async fn handle(
&self,
req: Request<Incoming>,
) -> Result<Response<Full<Bytes>>, ServerError> {
let method = req.method().clone();
let path = req.uri().path();
tracing::debug!("HTTP long-poll: {method} {path}");
let response = match (&method, path) {
(&Method::POST, "/lp/handshake") => self.handle_handshake(req).await,
(&Method::POST, "/lp/send") => self.handle_send(req).await,
(&Method::POST, "/lp/recv") => self.handle_recv(req).await,
(&Method::POST, "/lp/disconnect") => self.handle_disconnect(req).await,
_ => Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from_static(b"not found")))
.map_err(ServerError::from),
};
match response {
Ok(resp) => Ok(resp),
Err(e) => {
tracing::error!("handler error: {e}");
Ok(error_response(StatusCode::INTERNAL_SERVER_ERROR, &e)?)
}
}
}
async fn handle_handshake(
&self,
req: Request<Incoming>,
) -> Result<Response<Full<Bytes>>, ServerError> {
let body = read_body(req, self.max_body_size).await?;
let response_slot: Arc<Mutex<Option<Vec<u8>>>> = Arc::new(Mutex::new(None));
let http_handshake = HttpHandshake {
challenge_bytes: Some(body),
response_slot: response_slot.clone(),
};
let now = TimestampSeconds::now();
let result = handshake::respond::<Sendable, _, _, _, _>(
http_handshake,
|_handshake, peer_id| {
let conn = HttpLongPollTransport::new(peer_id);
(conn.clone(), conn)
},
&self.signer,
&self.nonce_cache,
self.our_peer_id,
self.discovery_audience,
now,
self.handshake_max_drift,
)
.await;
let response_bytes = response_slot.lock().await.take();
match result {
Ok((authenticated, conn)) => {
let peer_id = authenticated.peer_id();
let session_id = SessionId::random();
tracing::info!(
"HTTP long-poll handshake complete: peer {peer_id}, session {session_id}"
);
self.sessions
.insert(
session_id,
SessionEntry {
peer_id,
connection: conn.clone(),
authenticated: Some(authenticated),
},
)
.await;
let response_bytes = response_bytes.ok_or(ServerError::HandshakeNoResponse)?;
Ok(Response::builder()
.status(StatusCode::OK)
.header(SESSION_ID_HEADER, session_id.to_hex())
.header("content-type", "application/octet-stream")
.body(Full::new(Bytes::from(response_bytes)))?)
}
Err(e) => {
tracing::warn!("handshake failed: {e}");
if let Some(response_bytes) = response_bytes {
Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("content-type", "application/octet-stream")
.body(Full::new(Bytes::from(response_bytes)))?)
} else {
Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Full::new(Bytes::from(e.to_string())))?)
}
}
}
}
async fn handle_send(
&self,
req: Request<Incoming>,
) -> Result<Response<Full<Bytes>>, ServerError> {
let session_id = extract_session_id(&req)?;
let entry = self
.sessions
.get(&session_id)
.await
.ok_or(ServerError::SessionNotFound)?;
let body = read_body(req, self.max_body_size).await?;
tracing::debug!(
"POST /lp/send: peer {} ({} bytes)",
entry.peer_id,
body.len()
);
entry
.connection
.push_inbound(body)
.await
.map_err(|_| ServerError::ChanSend)?;
Ok(Response::builder()
.status(StatusCode::NO_CONTENT)
.body(Full::new(Bytes::new()))?)
}
async fn handle_recv(
&self,
req: Request<Incoming>,
) -> Result<Response<Full<Bytes>>, ServerError> {
let session_id = extract_session_id(&req)?;
let entry = self
.sessions
.get(&session_id)
.await
.ok_or(ServerError::SessionNotFound)?;
tracing::debug!("POST /lp/recv: peer {} waiting...", entry.peer_id);
let pull_fut = Sendable::from_future(async move { entry.connection.pull_outbound().await });
match self.timeout.timeout(self.poll_timeout, pull_fut).await {
Ok(Ok(bytes)) => {
tracing::debug!("POST /lp/recv: delivering {} bytes", bytes.len());
Ok(Response::builder()
.status(StatusCode::OK)
.header("content-type", "application/octet-stream")
.body(Full::new(Bytes::from(bytes)))?)
}
Ok(Err(_)) => {
tracing::debug!("POST /lp/recv: channel closed");
Ok(Response::builder()
.status(StatusCode::GONE)
.body(Full::new(Bytes::from_static(b"session closed")))?)
}
Err(_timed_out) => {
tracing::debug!("POST /lp/recv: poll timeout");
Ok(Response::builder()
.status(StatusCode::NO_CONTENT)
.body(Full::new(Bytes::new()))?)
}
}
}
async fn handle_disconnect(
&self,
req: Request<Incoming>,
) -> Result<Response<Full<Bytes>>, ServerError> {
let session_id = extract_session_id(&req)?;
if let Some(entry) = self.sessions.remove(&session_id).await {
tracing::info!(
"POST /lp/disconnect: peer {} session {session_id}",
entry.peer_id
);
entry.connection.close();
}
Ok(Response::builder()
.status(StatusCode::NO_CONTENT)
.body(Full::new(Bytes::new()))?)
}
pub async fn take_authenticated(
&self,
session_id: &SessionId,
) -> Option<Authenticated<HttpLongPollTransport, Sendable>> {
let mut sessions = self.sessions.sessions.lock().await;
sessions
.get_mut(session_id)
.and_then(|entry| entry.authenticated.take())
}
}
/// One-shot handshake channel backed by a single HTTP request/response pair.
///
/// `recv` hands out the request body (the client's handshake message) once.
/// `send` stores the outgoing message in `response_slot`, where the caller
/// reads it back after the handshake to build the HTTP response body.
struct HttpHandshake {
// Request body returned by the first `recv`; `None` once it has been taken.
challenge_bytes: Option<Vec<u8>>,
// Shared slot that `send` writes the outgoing message into.
response_slot: Arc<Mutex<Option<Vec<u8>>>>,
}
impl subduction_core::handshake::Handshake<Sendable> for HttpHandshake {
type Error = ServerError;
fn send(&mut self, bytes: Vec<u8>) -> BoxFuture<'_, Result<(), Self::Error>> {
let slot = self.response_slot.clone();
async move {
*slot.lock().await = Some(bytes);
Ok(())
}
.boxed()
}
fn recv(&mut self) -> BoxFuture<'_, Result<Vec<u8>, Self::Error>> {
let bytes = self.challenge_bytes.take();
async move { bytes.ok_or(ServerError::HandshakeNoChallenge) }.boxed()
}
}
/// Parses the session id carried in the [`SESSION_ID_HEADER`] request header.
///
/// # Errors
///
/// Returns [`ServerError::InvalidSessionId`] when:
/// - the header is absent;
/// - its value is not a valid string (`HeaderValue::to_str` fails);
/// - `SessionId::from_hex` rejects it.
fn extract_session_id<T>(req: &Request<T>) -> Result<SessionId, ServerError> {
    req.headers()
        .get(SESSION_ID_HEADER)
        .and_then(|value| value.to_str().ok())
        .and_then(|hex| SessionId::from_hex(hex))
        .ok_or(ServerError::InvalidSessionId)
}
/// Reads the whole request body into memory, capped at `max_size` bytes.
///
/// # Errors
///
/// Any failure while collecting the body is reported as
/// [`ServerError::BodyTooLarge`]. This covers both exceeding the limit and an
/// error from the underlying stream.
async fn read_body(req: Request<Incoming>, max_size: usize) -> Result<Vec<u8>, ServerError> {
    let capped_body = Limited::new(req.into_body(), max_size);
    match capped_body.collect().await {
        Ok(collected) => {
            let bytes = collected.to_bytes();
            Ok(bytes.to_vec())
        }
        Err(_) => Err(ServerError::BodyTooLarge),
    }
}
/// Builds a plain-text response with the given `status`, using `err`'s
/// `Display` output as the body.
///
/// # Errors
///
/// Propagates any error from the `http` response builder.
fn error_response(
    status: StatusCode,
    err: &ServerError,
) -> Result<Response<Full<Bytes>>, hyper::http::Error> {
    let message = err.to_string();
    let body = Full::new(Bytes::from(message));
    Response::builder().status(status).body(body)
}