1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use axum::body::Body;
7use axum::extract::{Request as AxumRequest, State};
8use axum::response::IntoResponse;
9use axum::routing;
10use axum::{Extension, Router};
11
12use hyper_util::client::legacy::Client;
13use hyper_util::client::legacy::connect::HttpConnector;
14use hyper_util::rt::{TokioExecutor, TokioTimer};
15
16use crate::backend::{BackendPool, BaseUri};
17use crate::balancing::{BalanceCtx, ProxyErrorKind, ResultEvent, RoundRobin, StartEvent};
18use crate::policy::Policy;
19use crate::proxy::{bad_gateway, gateway_timeout, proxy_request};
20use crate::route::{FixedRewrite, Rewrite, RewriteSpec, RouteMeta};
21use crate::routing::{NoRouteKey, RouteCtx};
22use crate::{Method, PartsCtx, RequestScope, Routes};
23
24struct Inner {
25 client: Client<HttpConnector, Body>,
26 state: Arc<http::Extensions>,
27 map_body_limit: usize,
28 request_timeout: Duration,
29}
30
31pub struct App {
32 router: Router,
33}
34
35pub struct AppBuilder {
36 backends: HashMap<String, Vec<String>>,
37 mounted: Vec<Routes>,
38 policies: HashMap<String, Arc<Policy>>,
39 default_policy: Arc<Policy>,
40 request_timeout: Duration,
41 connect_timeout: Duration,
42 pool_idle_timeout: Duration,
43 map_body_limit: usize,
44 state: http::Extensions,
45}
46
47impl AppBuilder {
48 pub fn new() -> Self {
49 Self {
50 backends: HashMap::new(),
51 mounted: Vec::new(),
52 policies: HashMap::new(),
53 default_policy: Arc::new(Policy::new().router(NoRouteKey).balancer(RoundRobin::new())),
54 request_timeout: Duration::from_secs(30),
55 connect_timeout: Duration::from_secs(5),
56 pool_idle_timeout: Duration::from_secs(90),
57 map_body_limit: 2 * 1024 * 1024,
58 state: http::Extensions::new(),
59 }
60 }
61
62 pub fn backend<I, S>(mut self, service: &str, urls: I) -> Self
63 where
64 I: IntoIterator<Item = S>,
65 S: Into<String>,
66 {
67 self.backends.insert(
68 service.to_string(),
69 urls.into_iter().map(|s| s.into()).collect(),
70 );
71 self
72 }
73
74 pub fn policy(mut self, name: &str, policy: Policy) -> Self {
75 self.policies.insert(name.to_string(), Arc::new(policy));
76 self
77 }
78
79 pub fn default_policy(mut self, policy: Policy) -> Self {
80 self.default_policy = Arc::new(policy);
81 self
82 }
83
84 pub fn request_timeout(mut self, d: Duration) -> Self {
85 self.request_timeout = d;
86 self
87 }
88
89 pub fn connect_timeout(mut self, d: Duration) -> Self {
90 self.connect_timeout = d;
91 self
92 }
93
94 pub fn pool_idle_timeout(mut self, d: Duration) -> Self {
95 self.pool_idle_timeout = d;
96 self
97 }
98
99 pub fn map_body_limit(mut self, bytes: usize) -> Self {
100 self.map_body_limit = bytes;
101 self
102 }
103
104 pub fn state<T: Clone + Send + Sync + 'static>(mut self, val: T) -> Self {
105 self.state.insert(val);
106 self
107 }
108
109 pub fn mount(mut self, routes: Routes) -> Self {
110 self.mounted.push(routes);
111 self
112 }
113
114 pub fn build(self) -> Result<App, String> {
115 let mut connector = HttpConnector::new();
117 connector.set_nodelay(true);
118 connector.set_connect_timeout(Some(self.connect_timeout));
119 connector.set_keepalive(Some(self.pool_idle_timeout));
120
121 let client = Client::builder(TokioExecutor::new())
122 .pool_timer(TokioTimer::new())
123 .pool_idle_timeout(self.pool_idle_timeout)
124 .build(connector);
125
126 let pools: HashMap<_, _> = self
128 .backends
129 .into_iter()
130 .map(|(svc, urls)| {
131 let bases = urls
132 .iter()
133 .map(|u| BaseUri::parse(u))
134 .collect::<Result<Vec<_>, _>>()?;
135 Ok((svc, Arc::new(BackendPool::new(bases))))
136 })
137 .collect::<Result<_, String>>()?;
138
139 let inner = Arc::new(Inner {
140 client,
141 state: Arc::new(self.state),
142 request_timeout: self.request_timeout,
143 map_body_limit: self.map_body_limit,
144 });
145
146 let mut router = Router::new();
148
149 for svc_routes in self.mounted {
150 let pool = pools
151 .get(svc_routes.service)
152 .ok_or_else(|| {
153 format!(
154 "backend for service `{}` is not registered",
155 svc_routes.service,
156 )
157 })?
158 .clone();
159
160 router = mount_service(
161 router,
162 svc_routes,
163 &self.policies,
164 self.default_policy.clone(),
165 pool,
166 )?;
167 }
168
169 let router = router.with_state(inner);
170
171 Ok(App { router })
172 }
173}
174
175impl App {
176 pub fn builder() -> AppBuilder {
177 AppBuilder::new()
178 }
179}
180
181pub async fn run(addr: SocketAddr, app: App) -> std::io::Result<()> {
182 let listener = tokio::net::TcpListener::bind(addr).await?;
183 axum::serve(listener, app.router).await
185}
186
187fn mount_service(
188 mut router: Router<Arc<Inner>>,
189 routes: Routes,
190 policies: &HashMap<String, Arc<Policy>>,
191 default_policy: Arc<Policy>,
192 pool: Arc<BackendPool>,
193) -> Result<Router<Arc<Inner>>, String> {
194 for rd in routes.routes {
206 let full_path = join(routes.prefix, rd.path);
207
208 let policy = resolve_policy(routes.policy, rd.policy, policies, &default_policy)?;
209
210 let meta = Arc::new(RouteMeta {
211 service: routes.service,
212 route_path: rd.path,
213 prefix: routes.prefix,
214 rewrite: match rd.rewrite {
215 RewriteSpec::StripPrefix => Rewrite::StripPrefix,
216 RewriteSpec::Static(to) => Rewrite::Static(FixedRewrite::new(to)),
217 RewriteSpec::Template(tpl) => Rewrite::Template(tpl),
218 },
219 pool: Arc::clone(&pool),
220 policy,
221 pipeline: rd.pipeline,
222 });
223
224 let method_router = method_router(rd.method).layer(Extension(meta));
225
226 router = router.route(&full_path, method_router);
227 }
228
229 Ok(router)
230}
231
232fn resolve_policy(
233 service_policy: Option<&'static str>,
234 route_policy: Option<&'static str>,
235 registry: &HashMap<String, Arc<Policy>>,
236 default_policy: &Arc<Policy>,
237) -> Result<Arc<Policy>, String> {
238 let effective = route_policy.or(service_policy);
239
240 match effective {
241 Some(name) => registry
242 .get(name)
243 .cloned()
244 .ok_or_else(|| format!("policy `{name}` is not registered")),
245 None => Ok(default_policy.clone()),
246 }
247}
248
/// Joins a service mount prefix and a route path into one axum route path.
///
/// Behaviour:
/// - Trailing slashes on `prefix` are dropped.
/// - Exactly one `/` separates the two parts, so `("/api/", "/users")`,
///   `("/api", "users")` and `("/api", "/users")` all yield `"/api/users"`.
/// - An empty prefix mounts the path at the root: `("", "users")` yields
///   `"/users"`.
/// - An empty path yields the prefix followed by `/`: `("/api", "")` yields
///   `"/api/"`.
///
/// The result always starts with `/`, because axum's `Router::route` panics on
/// paths that do not. A prefix written without a leading slash (e.g. `"api"`)
/// is therefore rooted (`"/api"`).
fn join(prefix: &str, path: &str) -> String {
    // `trim_end_matches` is a no-op when there is no trailing slash.
    let prefix = prefix.trim_end_matches('/');
    // Drop one leading slash; the separator is re-added below. Any further
    // leading slashes are kept verbatim, as before.
    let path = path.strip_prefix('/').unwrap_or(path);

    // +2: a possible rooting `/` plus the separator `/`.
    let mut joined = String::with_capacity(prefix.len() + path.len() + 2);
    if !prefix.is_empty() && !prefix.starts_with('/') {
        joined.push('/');
    }
    joined.push_str(prefix);
    joined.push('/');
    joined.push_str(path);
    joined
}
266
267fn method_router(method: Method) -> routing::MethodRouter<Arc<Inner>> {
268 match method {
269 Method::Get => routing::get(proxy_handler),
270 Method::Post => routing::post(proxy_handler),
271 Method::Put => routing::put(proxy_handler),
272 Method::Delete => routing::delete(proxy_handler),
273 Method::Patch => routing::patch(proxy_handler),
274 Method::Head => routing::head(proxy_handler),
275 Method::Options => routing::options(proxy_handler),
276 }
277}
278
279async fn proxy_handler(
280 State(inner): State<Arc<Inner>>,
281 Extension(meta): Extension<Arc<RouteMeta>>,
282 req: AxumRequest,
283) -> axum::response::Response {
284 let pool = &meta.pool;
285 let (mut parts, body) = req.into_parts();
286
287 let body = if let Some(pipeline) = meta.pipeline {
289 let ctx = PartsCtx::new(meta.service, meta.route_path, &mut parts);
290 let scope = RequestScope::with_shared(Arc::clone(&inner.state), body, inner.map_body_limit);
291
292 match pipeline(ctx, scope).await {
293 Ok(body) => body,
294 Err(err) => return err.into_response(),
295 }
296 } else {
297 body
298 };
299
300 let route_ctx = RouteCtx {
302 service: meta.service,
303 prefix: meta.prefix,
304 route_path: meta.route_path,
305 method: &parts.method,
306 uri: &parts.uri,
307 headers: &parts.headers,
308 };
309 let routing = meta.policy.router.route(&route_ctx, pool);
310
311 let balance_ctx = BalanceCtx {
313 service: meta.service,
314 affinity: routing.affinity.as_ref(),
315 pool,
316 candidates: routing.candidates,
317 };
318 let Some(backend_index) = meta.policy.balancer.pick(&balance_ctx) else {
319 return bad_gateway("no backends selected by balancer");
320 };
321 let Some(backend) = pool.get(backend_index) else {
322 return bad_gateway("balancer returned invalid backend index");
323 };
324
325 meta.policy.balancer.on_start(&StartEvent {
327 service: meta.service,
328 backend_index,
329 });
330
331 let started_at = Instant::now();
332
333 let result = tokio::time::timeout(
334 inner.request_timeout,
335 proxy_request(backend, &inner.client, &meta, parts, body),
336 )
337 .await
338 .unwrap_or_else(|_| Err(ProxyErrorKind::Timeout));
339
340 match result {
341 Ok(response) => {
342 let elapsed = started_at.elapsed();
343
344 meta.policy.balancer.on_result(&ResultEvent {
345 service: meta.service,
346 backend_index,
347 status: Some(response.status()),
348 error: None,
349 head_latency: elapsed,
350 });
351
352 response
353 }
354 Err(error) => {
355 let elapsed = started_at.elapsed();
356
357 meta.policy.balancer.on_result(&ResultEvent {
358 service: meta.service,
359 backend_index,
360 status: None,
361 error: Some(error),
362 head_latency: elapsed,
363 });
364
365 match error {
366 ProxyErrorKind::NoBackends => bad_gateway("no backends"),
367 ProxyErrorKind::InvalidUpstreamUri => bad_gateway("bad upstream uri"),
368 ProxyErrorKind::UpstreamRequest => bad_gateway("upstream request failed"),
369 ProxyErrorKind::Timeout => gateway_timeout("upstream request timed out"),
370 }
371 }
372 }
373}