// axum_h3/lib.rs

1use std::future::Future;
2
3use axum::body::Bytes;
4use h3_util::{server::H3Acceptor, server_body::H3IncomingServer};
5use hyper::{Request, Response, body::Body};
6
7/// Accept each connection from acceptor, then for each connection
8/// accept each request. Spawn a task to handle each request.
9async fn serve_inner<AC, F>(
10    svc: axum::Router,
11    mut acceptor: AC,
12    signal: F,
13) -> Result<(), h3_util::Error>
14where
15    AC: H3Acceptor,
16    F: Future<Output = ()>,
17{
18    let svc = tower::ServiceBuilder::new()
19        //.add_extension(Arc::new(ConnInfo { addr, certificates }))
20        .service(svc);
21
22    // TODO: tonic body is wrapped? Is it for error to status conversion?
23    // use tower::ServiceExt;
24    // let h_svc =
25    //     hyper_util::service::TowerToHyperService::new(svc.map_request(|req: http::Request<_>| {
26    //         req.map(tonic::body::boxed::<crate::H3IncomingServer<AC::RS, Bytes>>)
27    //     }));
28
29    let h_svc = hyper_util::service::TowerToHyperService::new(svc);
30
31    let mut sig = std::pin::pin!(signal);
32    tracing::debug!("loop start");
33    loop {
34        tracing::debug!("loop");
35        // get the next stream to run http on
36        let conn = tokio::select! {
37            res = acceptor.accept() =>{
38                match res{
39                Ok(x) => x,
40                Err(e) => {
41                    tracing::error!("accept error : {e}");
42                    return Err(e);
43                }
44            }
45            }
46            _ = &mut sig =>{
47                tracing::debug!("cancellation triggered");
48                return Ok(());
49            }
50        };
51
52        let Some(conn) = conn else {
53            tracing::debug!("acceptor end of conn");
54            return Ok(());
55        };
56
57        let h_svc_cp = h_svc.clone();
58        tokio::spawn(async move {
59            let mut conn = match h3::server::Connection::new(conn).await {
60                Ok(c) => c,
61                Err(e) => {
62                    tracing::debug!("server connection failed: {}", e);
63                    return;
64                }
65            };
66            loop {
67                let resolver = match conn.accept().await {
68                    Ok(req) => match req {
69                        Some(r) => r,
70                        None => {
71                            tracing::debug!("server connection ended:");
72                            break;
73                        }
74                    },
75                    Err(e) => {
76                        tracing::debug!("server connection accept failed: {}", e);
77                        break;
78                    }
79                };
80                let h_svc_cp = h_svc_cp.clone();
81                tokio::spawn(async move {
82                    let (req, stream) = match resolver.resolve_request().await {
83                        Ok(req) => req,
84                        Err(e) => {
85                            tracing::debug!("fail resolve request {e:#?}");
86                            return;
87                        }
88                    };
89                    if let Err(e) = serve_request::<AC, _, _>(req, stream, h_svc_cp.clone()).await {
90                        tracing::debug!("server request failed: {}", e);
91                    }
92                });
93            }
94        });
95    }
96}
97
98async fn serve_request<AC, SVC, BD>(
99    request: Request<()>,
100    stream: h3::server::RequestStream<
101        <<AC as H3Acceptor>::CONN as h3::quic::OpenStreams<Bytes>>::BidiStream,
102        Bytes,
103    >,
104    service: SVC,
105) -> Result<(), h3_util::Error>
106where
107    AC: H3Acceptor,
108    SVC: hyper::service::Service<
109            Request<H3IncomingServer<AC::RS, Bytes>>,
110            Response = Response<BD>,
111            Error = std::convert::Infallible,
112        >,
113    SVC::Future: 'static,
114    BD: Body + 'static,
115    BD::Error: Into<h3_util::Error>,
116    <BD as Body>::Error: Into<h3_util::Error> + std::error::Error + Send + Sync,
117    <BD as Body>::Data: Send + Sync,
118{
119    tracing::debug!("serving request");
120    let (parts, _) = request.into_parts();
121    let (mut w, r) = stream.split();
122
123    let req = Request::from_parts(parts, H3IncomingServer::new(r));
124    tracing::debug!("serving request call service");
125    let res = service.call(req).await?;
126
127    let (res_h, res_b) = res.into_parts();
128
129    // write header
130    tracing::debug!("serving request write header");
131    w.send_response(Response::from_parts(res_h, ())).await?;
132
133    // write body or trailer.
134    h3_util::server_body::send_h3_server_body::<BD, AC::BS>(&mut w, res_b).await?;
135
136    tracing::debug!("serving request end");
137    Ok(())
138}
139
140pub struct H3Router(axum::Router);
141
142impl H3Router {
143    pub fn new(inner: axum::Router) -> Self {
144        Self(inner)
145    }
146}
147
148impl From<axum::Router> for H3Router {
149    fn from(value: axum::Router) -> Self {
150        Self::new(value)
151    }
152}
153
154impl H3Router {
155    /// Runs the service on acceptor until shutdown.
156    pub async fn serve_with_shutdown<AC, F>(
157        self,
158        acceptor: AC,
159        signal: F,
160    ) -> Result<(), h3_util::Error>
161    where
162        AC: H3Acceptor,
163        F: Future<Output = ()>,
164    {
165        serve_inner(self.0, acceptor, signal).await
166    }
167
168    /// Runs all services on acceptor
169    pub async fn serve<AC>(self, acceptor: AC) -> Result<(), h3_util::Error>
170    where
171        AC: H3Acceptor,
172    {
173        self.serve_with_shutdown(acceptor, async {
174            // never returns
175            futures::future::pending().await
176        })
177        .await
178    }
179}