axum_h3/
lib.rs

1use std::future::Future;
2
3use axum::body::Bytes;
4use h3_util::{server::H3Acceptor, server_body::H3IncomingServer};
5use hyper::{Request, Response, body::Body, rt::Executor};
6
7/// Accept each connection from acceptor, then for each connection
8/// accept each request. Spawn a task to handle each request.
9async fn serve_inner<AC, F>(
10    svc: axum::Router,
11    executor: &h3_util::executor::SharedExec,
12    mut acceptor: AC,
13    signal: F,
14) -> Result<(), h3_util::Error>
15where
16    AC: H3Acceptor,
17    F: Future<Output = ()>,
18{
19    let svc = tower::ServiceBuilder::new()
20        //.add_extension(Arc::new(ConnInfo { addr, certificates }))
21        .service(svc);
22
23    // TODO: tonic body is wrapped? Is it for error to status conversion?
24    // use tower::ServiceExt;
25    // let h_svc =
26    //     hyper_util::service::TowerToHyperService::new(svc.map_request(|req: http::Request<_>| {
27    //         req.map(tonic::body::boxed::<crate::H3IncomingServer<AC::RS, Bytes>>)
28    //     }));
29
30    let h_svc = hyper_util::service::TowerToHyperService::new(svc);
31
32    let mut sig = std::pin::pin!(signal);
33    tracing::debug!("loop start");
34    loop {
35        tracing::debug!("loop");
36        // get the next stream to run http on
37        let conn = tokio::select! {
38            res = acceptor.accept() =>{
39                match res{
40                Ok(x) => x,
41                Err(e) => {
42                    tracing::error!("accept error : {e}");
43                    return Err(e);
44                }
45            }
46            }
47            _ = &mut sig =>{
48                tracing::debug!("cancellation triggered");
49                return Ok(());
50            }
51        };
52
53        let Some(conn) = conn else {
54            tracing::debug!("acceptor end of conn");
55            return Ok(());
56        };
57
58        // server each connection in the background
59        let h_svc_cp = h_svc.clone();
60        let executor_clone = executor.clone();
61        executor.execute(async move {
62            let mut conn = match h3::server::Connection::new(conn).await {
63                Ok(c) => c,
64                Err(e) => {
65                    tracing::debug!("server connection failed: {}", e);
66                    return;
67                }
68            };
69            loop {
70                let resolver = match conn.accept().await {
71                    Ok(req) => match req {
72                        Some(r) => r,
73                        None => {
74                            tracing::debug!("server connection ended:");
75                            break;
76                        }
77                    },
78                    Err(e) => {
79                        tracing::debug!("server connection accept failed: {}", e);
80                        break;
81                    }
82                };
83                let h_svc_cp = h_svc_cp.clone();
84                executor_clone.execute(async move {
85                    let (req, stream) = match resolver.resolve_request().await {
86                        Ok(req) => req,
87                        Err(e) => {
88                            tracing::debug!("fail resolve request {e:#?}");
89                            return;
90                        }
91                    };
92                    if let Err(e) = serve_request::<AC, _, _>(req, stream, h_svc_cp.clone()).await {
93                        tracing::debug!("server request failed: {}", e);
94                    }
95                });
96            }
97        });
98    }
99}
100
101async fn serve_request<AC, SVC, BD>(
102    request: Request<()>,
103    stream: h3::server::RequestStream<
104        <<AC as H3Acceptor>::CONN as h3::quic::OpenStreams<Bytes>>::BidiStream,
105        Bytes,
106    >,
107    service: SVC,
108) -> Result<(), h3_util::Error>
109where
110    AC: H3Acceptor,
111    SVC: hyper::service::Service<
112            Request<H3IncomingServer<AC::RS, Bytes>>,
113            Response = Response<BD>,
114            Error = std::convert::Infallible,
115        >,
116    SVC::Future: 'static,
117    BD: Body + 'static,
118    BD::Error: Into<h3_util::Error>,
119    <BD as Body>::Error: Into<h3_util::Error> + std::error::Error + Send + Sync,
120    <BD as Body>::Data: Send + Sync,
121{
122    tracing::debug!("serving request");
123    let (parts, _) = request.into_parts();
124    let (mut w, r) = stream.split();
125
126    let req = Request::from_parts(parts, H3IncomingServer::new(r));
127    tracing::debug!("serving request call service");
128    let res = service.call(req).await?;
129
130    let (res_h, res_b) = res.into_parts();
131
132    // write header
133    tracing::debug!("serving request write header");
134    w.send_response(Response::from_parts(res_h, ())).await?;
135
136    // write body or trailer.
137    h3_util::server_body::send_h3_server_body::<BD, AC::BS>(&mut w, res_b).await?;
138
139    tracing::debug!("serving request end");
140    Ok(())
141}
142
143pub struct H3Router {
144    inner: axum::Router,
145    executor: h3_util::executor::SharedExec, // expose this for the user.
146}
147
148impl H3Router {
149    pub fn new(inner: axum::Router) -> Self {
150        Self {
151            inner,
152            executor: h3_util::executor::SharedExec::tokio(),
153        }
154    }
155}
156
157impl From<axum::Router> for H3Router {
158    fn from(value: axum::Router) -> Self {
159        Self::new(value)
160    }
161}
162
163impl H3Router {
164    /// Runs the service on acceptor until shutdown.
165    pub async fn serve_with_shutdown<AC, F>(
166        self,
167        acceptor: AC,
168        signal: F,
169    ) -> Result<(), h3_util::Error>
170    where
171        AC: H3Acceptor,
172        F: Future<Output = ()>,
173    {
174        serve_inner(self.inner, &self.executor, acceptor, signal).await
175    }
176
177    /// Runs all services on acceptor
178    pub async fn serve<AC>(self, acceptor: AC) -> Result<(), h3_util::Error>
179    where
180        AC: H3Acceptor,
181    {
182        self.serve_with_shutdown(acceptor, async {
183            // never returns
184            futures::future::pending().await
185        })
186        .await
187    }
188}