1use std::future::Future;
2
3use axum::body::Bytes;
4use h3_util::{server::H3Acceptor, server_body::H3IncomingServer};
5use hyper::{Request, Response, body::Body};
6
7async fn serve_inner<AC, F>(
10 svc: axum::Router,
11 mut acceptor: AC,
12 signal: F,
13) -> Result<(), h3_util::Error>
14where
15 AC: H3Acceptor,
16 F: Future<Output = ()>,
17{
18 let svc = tower::ServiceBuilder::new()
19 .service(svc);
21
22 let h_svc = hyper_util::service::TowerToHyperService::new(svc);
30
31 let mut sig = std::pin::pin!(signal);
32 tracing::debug!("loop start");
33 loop {
34 tracing::debug!("loop");
35 let conn = tokio::select! {
37 res = acceptor.accept() =>{
38 match res{
39 Ok(x) => x,
40 Err(e) => {
41 tracing::error!("accept error : {e}");
42 return Err(e);
43 }
44 }
45 }
46 _ = &mut sig =>{
47 tracing::debug!("cancellation triggered");
48 return Ok(());
49 }
50 };
51
52 let Some(conn) = conn else {
53 tracing::debug!("acceptor end of conn");
54 return Ok(());
55 };
56
57 let h_svc_cp = h_svc.clone();
58 tokio::spawn(async move {
59 let mut conn = match h3::server::Connection::new(conn).await {
60 Ok(c) => c,
61 Err(e) => {
62 tracing::debug!("server connection failed: {}", e);
63 return;
64 }
65 };
66 loop {
67 let resolver = match conn.accept().await {
68 Ok(req) => match req {
69 Some(r) => r,
70 None => {
71 tracing::debug!("server connection ended:");
72 break;
73 }
74 },
75 Err(e) => {
76 tracing::debug!("server connection accept failed: {}", e);
77 break;
78 }
79 };
80 let h_svc_cp = h_svc_cp.clone();
81 tokio::spawn(async move {
82 let (req, stream) = match resolver.resolve_request().await {
83 Ok(req) => req,
84 Err(e) => {
85 tracing::debug!("fail resolve request {e:#?}");
86 return;
87 }
88 };
89 if let Err(e) = serve_request::<AC, _, _>(req, stream, h_svc_cp.clone()).await {
90 tracing::debug!("server request failed: {}", e);
91 }
92 });
93 }
94 });
95 }
96}
97
98async fn serve_request<AC, SVC, BD>(
99 request: Request<()>,
100 stream: h3::server::RequestStream<
101 <<AC as H3Acceptor>::CONN as h3::quic::OpenStreams<Bytes>>::BidiStream,
102 Bytes,
103 >,
104 service: SVC,
105) -> Result<(), h3_util::Error>
106where
107 AC: H3Acceptor,
108 SVC: hyper::service::Service<
109 Request<H3IncomingServer<AC::RS, Bytes>>,
110 Response = Response<BD>,
111 Error = std::convert::Infallible,
112 >,
113 SVC::Future: 'static,
114 BD: Body + 'static,
115 BD::Error: Into<h3_util::Error>,
116 <BD as Body>::Error: Into<h3_util::Error> + std::error::Error + Send + Sync,
117 <BD as Body>::Data: Send + Sync,
118{
119 tracing::debug!("serving request");
120 let (parts, _) = request.into_parts();
121 let (mut w, r) = stream.split();
122
123 let req = Request::from_parts(parts, H3IncomingServer::new(r));
124 tracing::debug!("serving request call service");
125 let res = service.call(req).await?;
126
127 let (res_h, res_b) = res.into_parts();
128
129 tracing::debug!("serving request write header");
131 w.send_response(Response::from_parts(res_h, ())).await?;
132
133 h3_util::server_body::send_h3_server_body::<BD, AC::BS>(&mut w, res_b).await?;
135
136 tracing::debug!("serving request end");
137 Ok(())
138}
139
140pub struct H3Router(axum::Router);
141
142impl H3Router {
143 pub fn new(inner: axum::Router) -> Self {
144 Self(inner)
145 }
146}
147
148impl From<axum::Router> for H3Router {
149 fn from(value: axum::Router) -> Self {
150 Self::new(value)
151 }
152}
153
154impl H3Router {
155 pub async fn serve_with_shutdown<AC, F>(
157 self,
158 acceptor: AC,
159 signal: F,
160 ) -> Result<(), h3_util::Error>
161 where
162 AC: H3Acceptor,
163 F: Future<Output = ()>,
164 {
165 serve_inner(self.0, acceptor, signal).await
166 }
167
168 pub async fn serve<AC>(self, acceptor: AC) -> Result<(), h3_util::Error>
170 where
171 AC: H3Acceptor,
172 {
173 self.serve_with_shutdown(acceptor, async {
174 futures::future::pending().await
176 })
177 .await
178 }
179}