Skip to main content

apigate_core/
app.rs

1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use axum::body::Body;
7use axum::extract::{Request as AxumRequest, State};
8use axum::routing;
9use axum::{Extension, Router};
10
11use hyper_util::client::legacy::Client;
12use hyper_util::client::legacy::connect::HttpConnector;
13use hyper_util::rt::{TokioExecutor, TokioTimer};
14
15use crate::backend::{BackendPool, BaseUri};
16use crate::balancing::{BalanceCtx, ProxyErrorKind, ResultEvent, RoundRobin, StartEvent};
17use crate::error::{
18    ApigateBuildError, ApigateCoreError, ApigateFrameworkError, ErrorRenderer,
19    default_error_renderer,
20};
21use crate::policy::Policy;
22use crate::proxy::proxy_request;
23use crate::route::{FixedRewrite, Rewrite, RewriteSpec, RouteMeta};
24use crate::routing::{NoRouteKey, RouteCtx};
25use crate::{ApigateError, Method, PartsCtx, RequestScope, Routes};
26
/// Shared runtime data for every proxied route; attached to the router as its state
/// (`router.with_state(inner)` in `AppBuilder::build`).
struct Inner {
    /// HTTP client used to forward requests to upstream backends.
    client: Client<HttpConnector, Body>,
    /// Application state registered through `AppBuilder::state`; handed to request pipelines.
    state: http::Extensions,
    /// Body-size limit in bytes handed to `RequestScope::new` for request pipelines.
    map_body_limit: usize,
    /// Deadline applied to each upstream proxy call; exceeding it yields a timeout error.
    request_timeout: Duration,
    /// Per-route metadata; a handler finds its entry through the `Extension<usize>` index
    /// attached when the route was mounted.
    route_metas: Box<[RouteMeta]>,
    /// Renders framework-generated errors into HTTP responses.
    error_renderer: Arc<ErrorRenderer>,
}
35
/// A fully built gateway: the axum router with every route mounted and state attached.
///
/// Produced by [`AppBuilder::build`] and served with [`run`].
pub struct App {
    router: Router,
}
39
/// Builder for [`App`].
///
/// Collects backends, routes, policies, timeouts, and shared state, then validates and
/// assembles them in [`AppBuilder::build`].
pub struct AppBuilder {
    /// Service name -> raw upstream base URLs (parsed in `build`).
    backends: HashMap<String, Vec<String>>,
    /// Route groups to mount, in registration order.
    mounted: Vec<Routes>,
    /// Named policies that services and routes can refer to by name.
    policies: HashMap<String, Arc<Policy>>,
    /// Policy used when neither a route nor its service names one.
    default_policy: Arc<Policy>,
    /// Deadline for each upstream proxy call.
    request_timeout: Duration,
    /// TCP connect timeout for upstream connections.
    connect_timeout: Duration,
    /// Idle timeout for pooled upstream connections (also used as the TCP keepalive time).
    pool_idle_timeout: Duration,
    /// Body-size limit in bytes handed to request pipelines.
    map_body_limit: usize,
    /// Application state made available to request pipelines.
    state: http::Extensions,
    /// Renderer for framework-generated errors.
    error_renderer: Arc<ErrorRenderer>,
}
52
53impl AppBuilder {
54    pub fn new() -> Self {
55        Self {
56            backends: HashMap::new(),
57            mounted: Vec::new(),
58            policies: HashMap::new(),
59            default_policy: Arc::new(Policy::new().router(NoRouteKey).balancer(RoundRobin::new())),
60            request_timeout: Duration::from_secs(30),
61            connect_timeout: Duration::from_secs(5),
62            pool_idle_timeout: Duration::from_secs(90),
63            map_body_limit: 2 * 1024 * 1024,
64            state: http::Extensions::new(),
65            error_renderer: Arc::new(default_error_renderer),
66        }
67    }
68
69    pub fn backend<I, S>(mut self, service: &str, urls: I) -> Self
70    where
71        I: IntoIterator<Item = S>,
72        S: Into<String>,
73    {
74        self.backends.insert(
75            service.to_string(),
76            urls.into_iter().map(|s| s.into()).collect(),
77        );
78        self
79    }
80
81    pub fn policy(mut self, name: &str, policy: Policy) -> Self {
82        self.policies.insert(name.to_string(), Arc::new(policy));
83        self
84    }
85
86    pub fn default_policy(mut self, policy: Policy) -> Self {
87        self.default_policy = Arc::new(policy);
88        self
89    }
90
91    pub fn request_timeout(mut self, d: Duration) -> Self {
92        self.request_timeout = d;
93        self
94    }
95
96    pub fn connect_timeout(mut self, d: Duration) -> Self {
97        self.connect_timeout = d;
98        self
99    }
100
101    pub fn pool_idle_timeout(mut self, d: Duration) -> Self {
102        self.pool_idle_timeout = d;
103        self
104    }
105
106    pub fn map_body_limit(mut self, bytes: usize) -> Self {
107        self.map_body_limit = bytes;
108        self
109    }
110
111    pub fn state<T: Clone + Send + Sync + 'static>(mut self, val: T) -> Self {
112        self.state.insert(val);
113        self
114    }
115
116    /// Sets the renderer for framework-generated errors (`ApigateError::*` constructors).
117    ///
118    /// This lets applications return a uniform JSON error envelope instead of plain text.
119    /// The renderer is used both for pipeline errors and proxy/runtime errors (502/504, etc.).
120    pub fn error_renderer<F>(mut self, renderer: F) -> Self
121    where
122        F: Fn(ApigateFrameworkError) -> axum::response::Response + Send + Sync + 'static,
123    {
124        self.error_renderer = Arc::new(renderer);
125        self
126    }
127
128    pub fn mount(mut self, routes: Routes) -> Self {
129        self.mounted.push(routes);
130        self
131    }
132
133    /// Registers backend URLs for `routes.service` and mounts these routes.
134    ///
135    /// Equivalent to:
136    /// `builder.backend(routes.service, urls).mount(routes)`
137    pub fn mount_service<I, S>(mut self, routes: Routes, urls: I) -> Self
138    where
139        I: IntoIterator<Item = S>,
140        S: Into<String>,
141    {
142        self.backends.insert(
143            routes.service.to_string(),
144            urls.into_iter().map(|s| s.into()).collect(),
145        );
146        self.mounted.push(routes);
147        self
148    }
149
150    pub fn build(self) -> Result<App, ApigateBuildError> {
151        // HTTP client
152        let mut connector = HttpConnector::new();
153        connector.set_nodelay(true);
154        connector.set_connect_timeout(Some(self.connect_timeout));
155        connector.set_keepalive(Some(self.pool_idle_timeout));
156
157        let client = Client::builder(TokioExecutor::new())
158            .pool_timer(TokioTimer::new())
159            .pool_idle_timeout(self.pool_idle_timeout)
160            .build(connector);
161
162        // backend pools
163        let mut pools: HashMap<String, Arc<BackendPool>> = HashMap::new();
164        for (svc, urls) in self.backends {
165            let mut bases = Vec::with_capacity(urls.len());
166            for url in urls {
167                let base = match BaseUri::parse(&url) {
168                    Ok(base) => base,
169                    Err(source) => {
170                        return Err(ApigateBuildError::InvalidBackendUri {
171                            service: svc.clone(),
172                            url,
173                            source,
174                        });
175                    }
176                };
177                bases.push(base);
178            }
179            pools.insert(svc, Arc::new(BackendPool::new(bases)));
180        }
181
182        // build router + route metadata table
183        let mut router = Router::new();
184        let mut route_metas = Vec::new();
185
186        for svc_routes in self.mounted {
187            let pool = pools
188                .get(svc_routes.service)
189                .ok_or(ApigateBuildError::BackendNotRegistered {
190                    service: svc_routes.service,
191                })?
192                .clone();
193
194            router = mount_routes(
195                router,
196                svc_routes,
197                &self.policies,
198                self.default_policy.clone(),
199                pool,
200                &mut route_metas,
201            )?;
202        }
203
204        let inner = Arc::new(Inner {
205            client,
206            state: self.state,
207            request_timeout: self.request_timeout,
208            map_body_limit: self.map_body_limit,
209            route_metas: route_metas.into_boxed_slice(),
210            error_renderer: self.error_renderer,
211        });
212
213        let router = router.with_state(inner);
214
215        Ok(App { router })
216    }
217}
218
219impl Default for AppBuilder {
220    fn default() -> Self {
221        Self::new()
222    }
223}
224
225impl App {
226    pub fn builder() -> AppBuilder {
227        AppBuilder::new()
228    }
229}
230
231pub async fn run(addr: SocketAddr, app: App) -> std::io::Result<()> {
232    let listener = tokio::net::TcpListener::bind(addr).await?;
233    // axum::serve intentionally simple (и это нам подходит как внутренняя обертка)
234    axum::serve(listener, app.router).await
235}
236
237fn mount_routes(
238    mut router: Router<Arc<Inner>>,
239    routes: Routes,
240    policies: &HashMap<String, Arc<Policy>>,
241    default_policy: Arc<Policy>,
242    pool: Arc<BackendPool>,
243    route_metas: &mut Vec<RouteMeta>,
244) -> Result<Router<Arc<Inner>>, ApigateBuildError> {
245    for rd in routes.routes {
246        let full_path = join(routes.prefix, rd.path);
247        let policy = resolve_policy(routes.policy, rd.policy, policies, &default_policy)?;
248
249        let meta = RouteMeta {
250            service: routes.service,
251            route_path: rd.path,
252            prefix: routes.prefix,
253            rewrite: match rd.rewrite {
254                RewriteSpec::StripPrefix => Rewrite::StripPrefix,
255                RewriteSpec::Static(to) => Rewrite::Static(FixedRewrite::new(to)),
256                RewriteSpec::Template(tpl) => Rewrite::Template(tpl),
257            },
258            pool: Arc::clone(&pool),
259            policy,
260            pipeline: rd.pipeline,
261        };
262
263        let route_idx = route_metas.len();
264        route_metas.push(meta);
265
266        let method_router = method_router(rd.method).layer(Extension(route_idx));
267
268        router = router.route(&full_path, method_router);
269    }
270
271    Ok(router)
272}
273
274fn resolve_policy(
275    service_policy: Option<&'static str>,
276    route_policy: Option<&'static str>,
277    registry: &HashMap<String, Arc<Policy>>,
278    default_policy: &Arc<Policy>,
279) -> Result<Arc<Policy>, ApigateBuildError> {
280    let effective = route_policy.or(service_policy);
281
282    match effective {
283        Some(name) => registry
284            .get(name)
285            .cloned()
286            .ok_or(ApigateBuildError::PolicyNotRegistered { name }),
287        None => Ok(default_policy.clone()),
288    }
289}
290
/// Concatenates a mount prefix and a route path with exactly one `/` at the seam.
///
/// All trailing slashes are dropped from `prefix`, and at most one leading slash is dropped
/// from `path`. For example, `("/sales", "/ping")` gives `"/sales/ping"` and
/// `("/sales", "/")` gives `"/sales/"`.
fn join(prefix: &str, path: &str) -> String {
    let head = prefix.trim_end_matches('/');
    let tail = path.strip_prefix('/').unwrap_or(path);

    let mut joined = String::with_capacity(head.len() + 1 + tail.len());
    joined.push_str(head);
    joined.push('/');
    joined.push_str(tail);
    joined
}
308
309fn method_router(method: Method) -> routing::MethodRouter<Arc<Inner>> {
310    match method {
311        Method::Get => routing::get(proxy_handler),
312        Method::Post => routing::post(proxy_handler),
313        Method::Put => routing::put(proxy_handler),
314        Method::Delete => routing::delete(proxy_handler),
315        Method::Patch => routing::patch(proxy_handler),
316        Method::Head => routing::head(proxy_handler),
317        Method::Options => routing::options(proxy_handler),
318    }
319}
320
321async fn proxy_handler(
322    State(inner): State<Arc<Inner>>,
323    Extension(route_idx): Extension<usize>,
324    req: AxumRequest,
325) -> axum::response::Response {
326    let meta = &inner.route_metas[route_idx];
327    let pool = &meta.pool;
328    let (mut parts, body) = req.into_parts();
329
330    // Pipeline: before hooks + body validation/map in a single pass
331    let body = if let Some(pipeline) = meta.pipeline {
332        let ctx = PartsCtx::new(meta.service, meta.route_path, &mut parts);
333        let scope = RequestScope::new(&inner.state, body, inner.map_body_limit);
334
335        match pipeline(ctx, scope).await {
336            Ok(body) => body,
337            Err(err) => return err.into_response_with(inner.error_renderer.as_ref()),
338        }
339    } else {
340        body
341    };
342
343    // Routing
344    let route_ctx = RouteCtx {
345        service: meta.service,
346        prefix: meta.prefix,
347        route_path: meta.route_path,
348        method: &parts.method,
349        uri: &parts.uri,
350        headers: &parts.headers,
351    };
352    let routing = meta.policy.router.route(&route_ctx, pool);
353
354    // Balancer
355    let balance_ctx = BalanceCtx {
356        service: meta.service,
357        affinity: routing.affinity.as_ref(),
358        pool,
359        candidates: routing.candidates,
360    };
361    let Some(backend_index) = meta.policy.balancer.pick(&balance_ctx) else {
362        return ApigateError::from(ApigateCoreError::NoBackendsSelectedByBalancer)
363            .into_response_with(inner.error_renderer.as_ref());
364    };
365    let Some(backend) = pool.get(backend_index) else {
366        return ApigateError::from(ApigateCoreError::InvalidBackendIndex)
367            .into_response_with(inner.error_renderer.as_ref());
368    };
369
370    // Make request
371    meta.policy.balancer.on_start(&StartEvent {
372        service: meta.service,
373        backend_index,
374    });
375
376    let started_at = Instant::now();
377
378    let result = tokio::time::timeout(
379        inner.request_timeout,
380        proxy_request(backend, &inner.client, meta, parts, body),
381    )
382    .await
383    .unwrap_or_else(|_| Err(ProxyErrorKind::Timeout));
384
385    match result {
386        Ok(response) => {
387            let elapsed = started_at.elapsed();
388
389            meta.policy.balancer.on_result(&ResultEvent {
390                service: meta.service,
391                backend_index,
392                status: Some(response.status()),
393                error: None,
394                head_latency: elapsed,
395            });
396
397            response
398        }
399        Err(error) => {
400            let elapsed = started_at.elapsed();
401
402            meta.policy.balancer.on_result(&ResultEvent {
403                service: meta.service,
404                backend_index,
405                status: None,
406                error: Some(error),
407                head_latency: elapsed,
408            });
409
410            match error {
411                ProxyErrorKind::NoBackends => ApigateError::from(ApigateCoreError::NoBackends)
412                    .into_response_with(inner.error_renderer.as_ref()),
413                ProxyErrorKind::InvalidUpstreamUri => {
414                    ApigateError::from(ApigateCoreError::InvalidUpstreamUri)
415                        .into_response_with(inner.error_renderer.as_ref())
416                }
417                ProxyErrorKind::UpstreamRequest => {
418                    ApigateError::from(ApigateCoreError::UpstreamRequestFailed)
419                        .into_response_with(inner.error_renderer.as_ref())
420                }
421                ProxyErrorKind::Timeout => {
422                    ApigateError::from(ApigateCoreError::UpstreamRequestTimedOut)
423                        .into_response_with(inner.error_renderer.as_ref())
424                }
425            }
426        }
427    }
428}