1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use axum::body::Body;
7use axum::extract::{Request as AxumRequest, State};
8use axum::response::IntoResponse;
9use axum::routing;
10use axum::{Extension, Router};
11
12use hyper_util::client::legacy::Client;
13use hyper_util::rt::TokioExecutor;
14
15use crate::backend::{BackendPool, BaseUri};
16use crate::balancing::{BalanceCtx, ProxyErrorKind, ResultEvent, RoundRobin, StartEvent};
17use crate::policy::Policy;
18use crate::proxy::{bad_gateway, proxy_request};
19use crate::route::{FixedRewrite, Rewrite, RewriteSpec, RouteMeta};
20use crate::routing::{NoRouteKey, RouteCtx};
21use crate::{Method, PartsCtx, Routes};
22
23struct Inner {
24 backends: HashMap<String, BackendPool>,
25 client: Client<hyper_util::client::legacy::connect::HttpConnector, Body>,
26 map_body_limit: usize,
27 _default_timeout: Duration, }
29
30pub struct App {
31 router: Router,
32}
33
34pub struct AppBuilder {
35 backends: HashMap<String, Vec<String>>,
36 mounted: Vec<Routes>,
37 policies: HashMap<String, Arc<Policy>>,
38 default_policy: Arc<Policy>,
39 default_timeout: Duration,
40 map_body_limit: usize,
41}
42
43impl AppBuilder {
44 pub fn new() -> Self {
45 Self {
46 backends: HashMap::new(),
47 mounted: Vec::new(),
48 policies: HashMap::new(),
49 default_policy: Arc::new(Policy::new().router(NoRouteKey).balancer(RoundRobin::new())),
50 default_timeout: Duration::from_secs(30),
51 map_body_limit: 2 * 1024 * 1024,
52 }
53 }
54
55 pub fn backend<I, S>(mut self, service: &str, urls: I) -> Self
56 where
57 I: IntoIterator<Item = S>,
58 S: Into<String>,
59 {
60 self.backends.insert(
61 service.to_string(),
62 urls.into_iter().map(|s| s.into()).collect(),
63 );
64 self
65 }
66
67 pub fn policy(mut self, name: &str, policy: Policy) -> Self {
68 self.policies.insert(name.to_string(), Arc::new(policy));
69 self
70 }
71
72 pub fn default_policy(mut self, policy: Policy) -> Self {
73 self.default_policy = Arc::new(policy);
74 self
75 }
76
77 pub fn default_timeout(mut self, d: Duration) -> Self {
78 self.default_timeout = d;
79 self
80 }
81
82 pub fn map_body_limit(mut self, bytes: usize) -> Self {
83 self.map_body_limit = bytes;
84 self
85 }
86
87 pub fn mount(mut self, routes: Routes) -> Self {
88 self.mounted.push(routes);
89 self
90 }
91
92 pub fn build(self) -> Result<App, String> {
93 let client = Client::builder(TokioExecutor::new()).build_http();
95
96 let mut pools = HashMap::with_capacity(self.backends.len());
98 for (svc, urls) in self.backends {
99 let bases = urls
100 .iter()
101 .map(|u| BaseUri::parse(u))
102 .collect::<Result<Vec<_>, _>>()?;
103 pools.insert(svc, BackendPool::new(bases));
104 }
105
106 let inner = Arc::new(Inner {
107 backends: pools,
108 client,
109 _default_timeout: self.default_timeout,
110 map_body_limit: self.map_body_limit,
111 });
112
113 let mut router = Router::new();
115
116 for svc_routes in self.mounted {
117 if !inner.backends.contains_key(svc_routes.service) {
118 return Err(format!(
119 "backend for service `{}` is not registered",
120 svc_routes.service
121 ));
122 }
123
124 router = mount_service(
125 router,
126 svc_routes,
127 &self.policies,
128 self.default_policy.clone(),
129 )?;
130 }
131
132 let router = router.with_state(inner);
133
134 Ok(App { router })
135 }
136}
137
138impl App {
139 pub fn builder() -> AppBuilder {
140 AppBuilder::new()
141 }
142}
143
144pub async fn run(addr: SocketAddr, app: App) -> std::io::Result<()> {
145 let listener = tokio::net::TcpListener::bind(addr).await?;
146 axum::serve(listener, app.router).await
148}
149
150fn mount_service(
151 mut router: Router<Arc<Inner>>,
152 routes: Routes,
153 policies: &HashMap<String, Arc<Policy>>,
154 default_policy: Arc<Policy>,
155) -> Result<Router<Arc<Inner>>, String> {
156 for rd in routes.routes {
168 let full_path = join(routes.prefix, rd.path);
169
170 let policy = resolve_policy(routes.policy, rd.policy, policies, &default_policy)?;
171
172 let meta = RouteMeta {
173 service: routes.service,
174 route_path: rd.path,
175 prefix: routes.prefix,
176 rewrite: match rd.rewrite {
177 RewriteSpec::StripPrefix => Rewrite::StripPrefix,
178 RewriteSpec::Static(to) => Rewrite::Static(FixedRewrite::new(to)),
179 RewriteSpec::Template(tpl) => Rewrite::Template(tpl),
180 },
181 policy,
182 before: rd.before,
183 map: rd.map,
184 };
185
186 let method_router = method_router(rd.method).layer(Extension(meta));
187
188 router = router.route(&full_path, method_router);
189 }
190
191 Ok(router)
192}
193
194fn resolve_policy(
195 service_policy: Option<&'static str>,
196 route_policy: Option<&'static str>,
197 registry: &HashMap<String, Arc<Policy>>,
198 default_policy: &Arc<Policy>,
199) -> Result<Arc<Policy>, String> {
200 let effective = route_policy.or(service_policy);
201
202 match effective {
203 Some(name) => registry
204 .get(name)
205 .cloned()
206 .ok_or_else(|| format!("policy `{name}` is not registered")),
207 None => Ok(default_policy.clone()),
208 }
209}
210
/// Joins a mount prefix and a route path into one axum route path.
///
/// All trailing `/` on `prefix` are dropped, and exactly one `/` is placed between it
/// and `path` (a single leading `/` on `path` is reused rather than doubled). The
/// result always starts with `/`, because axum's `Router::route` panics on paths that
/// do not, so a prefix such as `"api"` becomes `"/api"`.
///
/// Examples: `join("/api/", "users")` → `"/api/users"`, `join("/", "/")` → `"/"`,
/// `join("/api", "")` → `"/api/"`.
fn join(prefix: &str, path: &str) -> String {
    let prefix = prefix.trim_end_matches('/');
    let path = path.strip_prefix('/').unwrap_or(path);

    // +2 covers the optional leading slash and the separator.
    let mut joined = String::with_capacity(prefix.len() + path.len() + 2);
    if !prefix.is_empty() && !prefix.starts_with('/') {
        joined.push('/');
    }
    joined.push_str(prefix);
    joined.push('/');
    joined.push_str(path);
    joined
}
228
229fn method_router(method: Method) -> routing::MethodRouter<Arc<Inner>> {
230 match method {
231 Method::Get => routing::get(proxy_handler),
232 Method::Post => routing::post(proxy_handler),
233 Method::Put => routing::put(proxy_handler),
234 Method::Delete => routing::delete(proxy_handler),
235 Method::Patch => routing::patch(proxy_handler),
236 Method::Head => routing::head(proxy_handler),
237 Method::Options => routing::options(proxy_handler),
238 }
239}
240
241async fn proxy_handler(
242 State(inner): State<Arc<Inner>>,
243 Extension(meta): Extension<RouteMeta>,
244 mut req: AxumRequest,
245) -> axum::response::Response {
246 let Some(pool) = inner.backends.get(meta.service) else {
247 return bad_gateway(format!("unknown backend `{}`", meta.service));
248 };
249
250 if let Some(before) = meta.before {
252 let (mut parts, body) = req.into_parts();
253 let ctx = PartsCtx::new(meta.service, meta.route_path, &mut parts);
254
255 if let Err(err) = before(ctx).await {
256 return err.into_response();
257 }
258
259 req = http::Request::from_parts(parts, body);
260 }
261
262 if let Some(map) = meta.map {
264 let (mut parts, body) = req.into_parts();
265 let ctx = PartsCtx::new(meta.service, meta.route_path, &mut parts);
266
267 let body = match map(ctx, body, inner.map_body_limit).await {
268 Ok(body) => body,
269 Err(err) => return err.into_response(),
270 };
271
272 req = http::Request::from_parts(parts, body);
273 }
274
275 let route_ctx = RouteCtx {
277 service: meta.service,
278 route_path: meta.route_path,
279 method: req.method(),
280 uri: req.uri(),
281 headers: req.headers(),
282 };
283 let routing = meta.policy.router.route(&route_ctx, pool);
284
285 let balance_ctx = BalanceCtx {
287 service: meta.service,
288 affinity: routing.affinity.as_ref(),
289 pool,
290 candidates: routing.candidates,
291 };
292 let Some(backend_index) = meta.policy.balancer.pick(&balance_ctx) else {
293 return bad_gateway("no backends selected by balancer");
294 };
295 let Some(backend) = pool.get(backend_index) else {
296 return bad_gateway("balancer returned invalid backend index");
297 };
298
299 meta.policy.balancer.on_start(&StartEvent {
301 service: meta.service,
302 backend_index,
303 });
304
305 let started_at = Instant::now();
306
307 match proxy_request(backend, &inner.client, &meta, req).await {
308 Ok(response) => {
309 let elapsed = started_at.elapsed();
310
311 meta.policy.balancer.on_result(&ResultEvent {
312 service: meta.service,
313 backend_index,
314 status: Some(response.status()),
315 error: None,
316 head_latency: elapsed,
317 });
318
319 response
320 }
321 Err(error) => {
322 let elapsed = started_at.elapsed();
323
324 meta.policy.balancer.on_result(&ResultEvent {
325 service: meta.service,
326 backend_index,
327 status: None,
328 error: Some(error),
329 head_latency: elapsed,
330 });
331
332 match error {
333 ProxyErrorKind::NoBackends => bad_gateway("no backends"),
334 ProxyErrorKind::InvalidUpstreamUri => bad_gateway("bad upstream uri"),
335 ProxyErrorKind::UpstreamRequest => bad_gateway("upstream request failed"),
336 }
337 }
338 }
339}