use std::future::Future;
use std::net::SocketAddr;
use std::sync::Arc;
use bytes::Bytes;
use oxihttp_core::OxiHttpError;
use oxiquic_h3::{H3Connection, H3ServerBuilder, H3ServerEndpoint};
pub use oxiquic_h3::{H3Request, H3Response};
/// HTTP/3 server handle wrapping an [`H3ServerEndpoint`].
///
/// Created with [`H3Server::bind`] and driven with [`H3Server::serve`],
/// which consumes the server.
pub struct H3Server {
/// Endpoint produced by `H3ServerBuilder`; source of incoming connections.
endpoint: H3ServerEndpoint,
}
impl H3Server {
    /// Builds an HTTP/3 server endpoint for `addr` using `tls_config`.
    ///
    /// The endpoint is constructed through [`H3ServerBuilder`]; this wrapper
    /// only stores the result.
    ///
    /// # Errors
    ///
    /// Returns [`OxiHttpError::H3`] carrying the stringified builder error
    /// when the endpoint cannot be built.
    pub async fn bind(
        addr: SocketAddr,
        tls_config: rustls::ServerConfig,
    ) -> Result<Self, OxiHttpError> {
        // Keep the whole builder chain in one expression so any intermediate
        // builder values live exactly as long as the awaited build call.
        let built = H3ServerBuilder::new(addr)
            .with_tls_config(tls_config)
            .build()
            .await;

        built
            .map(|endpoint| Self { endpoint })
            .map_err(|err| OxiHttpError::H3(err.to_string()))
    }

    /// Returns the local socket address reported by the underlying endpoint.
    ///
    /// # Errors
    ///
    /// Returns [`OxiHttpError::H3`] carrying the stringified endpoint error
    /// when the address cannot be obtained.
    pub fn local_addr(&self) -> Result<SocketAddr, OxiHttpError> {
        match self.endpoint.local_addr() {
            Ok(addr) => Ok(addr),
            Err(err) => Err(OxiHttpError::H3(err.to_string())),
        }
    }

    /// Accepts connections and dispatches every request to `handler`.
    ///
    /// Each accepted connection is driven on its own task created with
    /// `tokio::spawn`, so this must be awaited inside a Tokio runtime. The
    /// handler is shared between tasks through an [`Arc`]; it receives the
    /// request head and the request body and produces the response to send.
    ///
    /// The future resolves to `Ok(())` once the endpoint's incoming stream
    /// yields `None`. Spawned connection tasks are detached and are not
    /// awaited before returning.
    pub async fn serve<F, Fut>(self, handler: F) -> Result<(), OxiHttpError>
    where
        F: Fn(H3Request, Bytes) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = H3Response> + Send + 'static,
    {
        let shared_handler = Arc::new(handler);
        let incoming = self.endpoint.incoming();

        loop {
            match incoming.next().await {
                Some(conn) => {
                    // One detached task per connection; the JoinHandle is
                    // intentionally dropped.
                    tokio::spawn(handle_h3_connection(conn, Arc::clone(&shared_handler)));
                }
                // Incoming stream finished: nothing more to accept.
                None => break Ok(()),
            }
        }
    }
}
/// Drives a single HTTP/3 connection, serving each request on its own task.
///
/// Requests are accepted from `conn` until `accept` reports either the end of
/// the connection (`Ok(None)`) or an error; both cases end the loop without
/// reporting anything to the caller.
///
/// For every accepted request a detached `tokio` task:
/// 1. clones the request head,
/// 2. reads the request body (a read failure silently abandons the request),
/// 3. awaits `handler` to obtain the response,
/// 4. sends the response, ignoring any send error.
async fn handle_h3_connection<F, Fut>(mut conn: H3Connection, handler: Arc<F>)
where
    F: Fn(H3Request, Bytes) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = H3Response> + Send + 'static,
{
    loop {
        let mut ctx = match conn.accept().await {
            Ok(Some(ctx)) => ctx,
            // Connection finished or failed: stop accepting either way.
            Ok(None) | Err(_) => break,
        };

        let callback = Arc::clone(&handler);
        tokio::spawn(async move {
            let request = ctx.request().clone();

            // Best effort: a request whose body cannot be read gets no response.
            let Ok(payload) = ctx.body().await else {
                return;
            };

            let response = callback(request, payload).await;
            // The peer may already be gone; a failed send is deliberately ignored.
            let _ = ctx.respond(response).await;
        });
    }
}