1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use axum::body::Body;
7use axum::extract::{Request as AxumRequest, State};
8use axum::routing;
9use axum::{Extension, Router};
10
11use hyper_util::client::legacy::Client;
12use hyper_util::client::legacy::connect::HttpConnector;
13use hyper_util::rt::{TokioExecutor, TokioTimer};
14
15use crate::backend::{BackendPool, BaseUri};
16use crate::balancing::{BalanceCtx, ProxyErrorKind, ResultEvent, RoundRobin, StartEvent};
17use crate::error::{
18 ApigateBuildError, ApigateCoreError, ApigateFrameworkError, ErrorRenderer,
19 default_error_renderer,
20};
21use crate::policy::Policy;
22use crate::proxy::proxy_request;
23use crate::route::{FixedRewrite, Rewrite, RewriteSpec, RouteMeta};
24use crate::routing::{NoRouteKey, RouteCtx};
25use crate::{ApigateError, Method, PartsCtx, RequestScope, Routes};
26
/// Runtime data shared by every proxied route. It is wrapped in an `Arc` and
/// installed as the router's `State` in `AppBuilder::build`.
struct Inner {
    /// HTTP client used to forward requests upstream.
    client: Client<HttpConnector, Body>,
    /// User-supplied values registered through `AppBuilder::state`; handed to
    /// route pipelines via `RequestScope`.
    state: http::Extensions,
    /// Body size limit (bytes) passed to `RequestScope::new` for route pipelines.
    map_body_limit: usize,
    /// Deadline applied to each upstream `proxy_request` call.
    request_timeout: Duration,
    /// Per-route metadata. A handler finds its entry through the `usize` index
    /// that `mount_routes` attaches to the route as an `Extension`.
    route_metas: Box<[RouteMeta]>,
    /// Turns framework errors into HTTP responses.
    error_renderer: Arc<ErrorRenderer>,
}
35
36pub struct App {
37 router: Router,
38}
39
/// Builder for [`App`]. It collects backends, route sets, policies, timeouts,
/// shared state and the error renderer, then assembles them in `build`.
pub struct AppBuilder {
    /// Upstream base URLs keyed by service name; parsed into `BaseUri`s in `build`.
    backends: HashMap<String, Vec<String>>,
    /// Route sets to mount, in registration order.
    mounted: Vec<Routes>,
    /// Named policies that a service or an individual route may reference.
    policies: HashMap<String, Arc<Policy>>,
    /// Policy used when neither the route nor its service names one.
    default_policy: Arc<Policy>,
    /// Timeout for each upstream request (default 30 s).
    request_timeout: Duration,
    /// Connect timeout for the upstream `HttpConnector` (default 5 s).
    connect_timeout: Duration,
    /// Idle timeout for pooled upstream connections (default 90 s). `build`
    /// also passes it to the connector's `set_keepalive`.
    pool_idle_timeout: Duration,
    /// Body size limit in bytes for route pipelines (default 2 MiB).
    map_body_limit: usize,
    /// Shared values exposed to route pipelines.
    state: http::Extensions,
    /// Converts framework errors into responses.
    error_renderer: Arc<ErrorRenderer>,
}
52
53impl AppBuilder {
54 pub fn new() -> Self {
55 Self {
56 backends: HashMap::new(),
57 mounted: Vec::new(),
58 policies: HashMap::new(),
59 default_policy: Arc::new(Policy::new().router(NoRouteKey).balancer(RoundRobin::new())),
60 request_timeout: Duration::from_secs(30),
61 connect_timeout: Duration::from_secs(5),
62 pool_idle_timeout: Duration::from_secs(90),
63 map_body_limit: 2 * 1024 * 1024,
64 state: http::Extensions::new(),
65 error_renderer: Arc::new(default_error_renderer),
66 }
67 }
68
69 pub fn backend<I, S>(mut self, service: &str, urls: I) -> Self
70 where
71 I: IntoIterator<Item = S>,
72 S: Into<String>,
73 {
74 self.backends.insert(
75 service.to_string(),
76 urls.into_iter().map(|s| s.into()).collect(),
77 );
78 self
79 }
80
81 pub fn policy(mut self, name: &str, policy: Policy) -> Self {
82 self.policies.insert(name.to_string(), Arc::new(policy));
83 self
84 }
85
86 pub fn default_policy(mut self, policy: Policy) -> Self {
87 self.default_policy = Arc::new(policy);
88 self
89 }
90
91 pub fn request_timeout(mut self, d: Duration) -> Self {
92 self.request_timeout = d;
93 self
94 }
95
96 pub fn connect_timeout(mut self, d: Duration) -> Self {
97 self.connect_timeout = d;
98 self
99 }
100
101 pub fn pool_idle_timeout(mut self, d: Duration) -> Self {
102 self.pool_idle_timeout = d;
103 self
104 }
105
106 pub fn map_body_limit(mut self, bytes: usize) -> Self {
107 self.map_body_limit = bytes;
108 self
109 }
110
111 pub fn state<T: Clone + Send + Sync + 'static>(mut self, val: T) -> Self {
112 self.state.insert(val);
113 self
114 }
115
116 pub fn error_renderer<F>(mut self, renderer: F) -> Self
121 where
122 F: Fn(ApigateFrameworkError) -> axum::response::Response + Send + Sync + 'static,
123 {
124 self.error_renderer = Arc::new(renderer);
125 self
126 }
127
128 pub fn mount(mut self, routes: Routes) -> Self {
129 self.mounted.push(routes);
130 self
131 }
132
133 pub fn mount_service<I, S>(mut self, routes: Routes, urls: I) -> Self
138 where
139 I: IntoIterator<Item = S>,
140 S: Into<String>,
141 {
142 self.backends.insert(
143 routes.service.to_string(),
144 urls.into_iter().map(|s| s.into()).collect(),
145 );
146 self.mounted.push(routes);
147 self
148 }
149
150 pub fn build(self) -> Result<App, ApigateBuildError> {
151 let mut connector = HttpConnector::new();
153 connector.set_nodelay(true);
154 connector.set_connect_timeout(Some(self.connect_timeout));
155 connector.set_keepalive(Some(self.pool_idle_timeout));
156
157 let client = Client::builder(TokioExecutor::new())
158 .pool_timer(TokioTimer::new())
159 .pool_idle_timeout(self.pool_idle_timeout)
160 .build(connector);
161
162 let mut pools: HashMap<String, Arc<BackendPool>> = HashMap::new();
164 for (svc, urls) in self.backends {
165 let mut bases = Vec::with_capacity(urls.len());
166 for url in urls {
167 let base = match BaseUri::parse(&url) {
168 Ok(base) => base,
169 Err(source) => {
170 return Err(ApigateBuildError::InvalidBackendUri {
171 service: svc.clone(),
172 url,
173 source,
174 });
175 }
176 };
177 bases.push(base);
178 }
179 pools.insert(svc, Arc::new(BackendPool::new(bases)));
180 }
181
182 let mut router = Router::new();
184 let mut route_metas = Vec::new();
185
186 for svc_routes in self.mounted {
187 let pool = pools
188 .get(svc_routes.service)
189 .ok_or(ApigateBuildError::BackendNotRegistered {
190 service: svc_routes.service,
191 })?
192 .clone();
193
194 router = mount_routes(
195 router,
196 svc_routes,
197 &self.policies,
198 self.default_policy.clone(),
199 pool,
200 &mut route_metas,
201 )?;
202 }
203
204 let inner = Arc::new(Inner {
205 client,
206 state: self.state,
207 request_timeout: self.request_timeout,
208 map_body_limit: self.map_body_limit,
209 route_metas: route_metas.into_boxed_slice(),
210 error_renderer: self.error_renderer,
211 });
212
213 let router = router.with_state(inner);
214
215 Ok(App { router })
216 }
217}
218
219impl Default for AppBuilder {
220 fn default() -> Self {
221 Self::new()
222 }
223}
224
225impl App {
226 pub fn builder() -> AppBuilder {
227 AppBuilder::new()
228 }
229}
230
231pub async fn run(addr: SocketAddr, app: App) -> std::io::Result<()> {
232 let listener = tokio::net::TcpListener::bind(addr).await?;
233 axum::serve(listener, app.router).await
235}
236
237fn mount_routes(
238 mut router: Router<Arc<Inner>>,
239 routes: Routes,
240 policies: &HashMap<String, Arc<Policy>>,
241 default_policy: Arc<Policy>,
242 pool: Arc<BackendPool>,
243 route_metas: &mut Vec<RouteMeta>,
244) -> Result<Router<Arc<Inner>>, ApigateBuildError> {
245 for rd in routes.routes {
246 let full_path = join(routes.prefix, rd.path);
247 let policy = resolve_policy(routes.policy, rd.policy, policies, &default_policy)?;
248
249 let meta = RouteMeta {
250 service: routes.service,
251 route_path: rd.path,
252 prefix: routes.prefix,
253 rewrite: match rd.rewrite {
254 RewriteSpec::StripPrefix => Rewrite::StripPrefix,
255 RewriteSpec::Static(to) => Rewrite::Static(FixedRewrite::new(to)),
256 RewriteSpec::Template(tpl) => Rewrite::Template(tpl),
257 },
258 pool: Arc::clone(&pool),
259 policy,
260 pipeline: rd.pipeline,
261 };
262
263 let route_idx = route_metas.len();
264 route_metas.push(meta);
265
266 let method_router = method_router(rd.method).layer(Extension(route_idx));
267
268 router = router.route(&full_path, method_router);
269 }
270
271 Ok(router)
272}
273
274fn resolve_policy(
275 service_policy: Option<&'static str>,
276 route_policy: Option<&'static str>,
277 registry: &HashMap<String, Arc<Policy>>,
278 default_policy: &Arc<Policy>,
279) -> Result<Arc<Policy>, ApigateBuildError> {
280 let effective = route_policy.or(service_policy);
281
282 match effective {
283 Some(name) => registry
284 .get(name)
285 .cloned()
286 .ok_or(ApigateBuildError::PolicyNotRegistered { name }),
287 None => Ok(default_policy.clone()),
288 }
289}
290
/// Joins a mount prefix and a route path into a single route path.
///
/// All trailing `/` characters are stripped from `prefix`. A single `/` is then
/// inserted before `path` unless `path` already starts with one. As a result an
/// empty `path` yields `"{prefix}/"`; for example, `join("/api", "")` is `"/api/"`.
fn join(prefix: &str, path: &str) -> String {
    // `trim_end_matches` leaves a prefix without a trailing '/' unchanged, so no
    // separate `ends_with` check is needed.
    let prefix = prefix.trim_end_matches('/');
    // The `+ 1` reserves room for the separator that may be inserted, so the
    // pushes below never reallocate.
    let mut joined = String::with_capacity(prefix.len() + path.len() + 1);
    joined.push_str(prefix);
    if !path.starts_with('/') {
        joined.push('/');
    }
    joined.push_str(path);
    joined
}
308
309fn method_router(method: Method) -> routing::MethodRouter<Arc<Inner>> {
310 match method {
311 Method::Get => routing::get(proxy_handler),
312 Method::Post => routing::post(proxy_handler),
313 Method::Put => routing::put(proxy_handler),
314 Method::Delete => routing::delete(proxy_handler),
315 Method::Patch => routing::patch(proxy_handler),
316 Method::Head => routing::head(proxy_handler),
317 Method::Options => routing::options(proxy_handler),
318 }
319}
320
321async fn proxy_handler(
322 State(inner): State<Arc<Inner>>,
323 Extension(route_idx): Extension<usize>,
324 req: AxumRequest,
325) -> axum::response::Response {
326 let meta = &inner.route_metas[route_idx];
327 let pool = &meta.pool;
328 let (mut parts, body) = req.into_parts();
329
330 let body = if let Some(pipeline) = meta.pipeline {
332 let ctx = PartsCtx::new(meta.service, meta.route_path, &mut parts);
333 let scope = RequestScope::new(&inner.state, body, inner.map_body_limit);
334
335 match pipeline(ctx, scope).await {
336 Ok(body) => body,
337 Err(err) => return err.into_response_with(inner.error_renderer.as_ref()),
338 }
339 } else {
340 body
341 };
342
343 let route_ctx = RouteCtx {
345 service: meta.service,
346 prefix: meta.prefix,
347 route_path: meta.route_path,
348 method: &parts.method,
349 uri: &parts.uri,
350 headers: &parts.headers,
351 };
352 let routing = meta.policy.router.route(&route_ctx, pool);
353
354 let balance_ctx = BalanceCtx {
356 service: meta.service,
357 affinity: routing.affinity.as_ref(),
358 pool,
359 candidates: routing.candidates,
360 };
361 let Some(backend_index) = meta.policy.balancer.pick(&balance_ctx) else {
362 return ApigateError::from(ApigateCoreError::NoBackendsSelectedByBalancer)
363 .into_response_with(inner.error_renderer.as_ref());
364 };
365 let Some(backend) = pool.get(backend_index) else {
366 return ApigateError::from(ApigateCoreError::InvalidBackendIndex)
367 .into_response_with(inner.error_renderer.as_ref());
368 };
369
370 meta.policy.balancer.on_start(&StartEvent {
372 service: meta.service,
373 backend_index,
374 });
375
376 let started_at = Instant::now();
377
378 let result = tokio::time::timeout(
379 inner.request_timeout,
380 proxy_request(backend, &inner.client, meta, parts, body),
381 )
382 .await
383 .unwrap_or_else(|_| Err(ProxyErrorKind::Timeout));
384
385 match result {
386 Ok(response) => {
387 let elapsed = started_at.elapsed();
388
389 meta.policy.balancer.on_result(&ResultEvent {
390 service: meta.service,
391 backend_index,
392 status: Some(response.status()),
393 error: None,
394 head_latency: elapsed,
395 });
396
397 response
398 }
399 Err(error) => {
400 let elapsed = started_at.elapsed();
401
402 meta.policy.balancer.on_result(&ResultEvent {
403 service: meta.service,
404 backend_index,
405 status: None,
406 error: Some(error),
407 head_latency: elapsed,
408 });
409
410 match error {
411 ProxyErrorKind::NoBackends => ApigateError::from(ApigateCoreError::NoBackends)
412 .into_response_with(inner.error_renderer.as_ref()),
413 ProxyErrorKind::InvalidUpstreamUri => {
414 ApigateError::from(ApigateCoreError::InvalidUpstreamUri)
415 .into_response_with(inner.error_renderer.as_ref())
416 }
417 ProxyErrorKind::UpstreamRequest => {
418 ApigateError::from(ApigateCoreError::UpstreamRequestFailed)
419 .into_response_with(inner.error_renderer.as_ref())
420 }
421 ProxyErrorKind::Timeout => {
422 ApigateError::from(ApigateCoreError::UpstreamRequestTimedOut)
423 .into_response_with(inner.error_renderer.as_ref())
424 }
425 }
426 }
427 }
428}