//! apigate_core — application/router wiring (`app.rs`).

1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use axum::body::Body;
7use axum::extract::{Request as AxumRequest, State};
8use axum::response::IntoResponse;
9use axum::routing;
10use axum::{Extension, Router};
11
12use hyper_util::client::legacy::Client;
13use hyper_util::client::legacy::connect::HttpConnector;
14use hyper_util::rt::{TokioExecutor, TokioTimer};
15
16use crate::backend::{BackendPool, BaseUri};
17use crate::balancing::{BalanceCtx, ProxyErrorKind, ResultEvent, RoundRobin, StartEvent};
18use crate::policy::Policy;
19use crate::proxy::{bad_gateway, gateway_timeout, proxy_request};
20use crate::route::{FixedRewrite, Rewrite, RewriteSpec, RouteMeta};
21use crate::routing::{NoRouteKey, RouteCtx};
22use crate::{Method, PartsCtx, RequestScope, Routes};
23
/// Per-process state shared by every request handler via `State<Arc<Inner>>`.
struct Inner {
    /// Pooled HTTP client used for all upstream (backend) requests.
    client: Client<HttpConnector, Body>,
    /// Application-provided extension values; handed to pipelines through `RequestScope`.
    state: Arc<http::Extensions>,
    /// Maximum number of body bytes a pipeline may buffer when validating/mapping.
    map_body_limit: usize,
    /// Upper bound on the total time spent proxying one request upstream.
    request_timeout: Duration,
}
30
/// A fully built gateway application: an axum [`Router`] ready to be served by [`run`].
pub struct App {
    // Assembled router with all mounted routes; consumed by `run` via `axum::serve`.
    router: Router,
}
34
/// Builder for [`App`]: collects backends, mounted routes, policies and
/// client tunables, then assembles everything in `build`.
pub struct AppBuilder {
    /// Service name -> list of backend base URLs (parsed during `build`).
    backends: HashMap<String, Vec<String>>,
    /// Route groups registered via `mount`.
    mounted: Vec<Routes>,
    /// Named policies that routes/services may reference by name.
    policies: HashMap<String, Arc<Policy>>,
    /// Policy applied when neither the route nor the service names one.
    default_policy: Arc<Policy>,
    /// End-to-end timeout applied to each proxied request.
    request_timeout: Duration,
    /// TCP connect timeout for the upstream connector.
    connect_timeout: Duration,
    /// Idle timeout for pooled upstream connections (also used as TCP keepalive interval).
    pool_idle_timeout: Duration,
    /// Max body bytes a pipeline may buffer.
    map_body_limit: usize,
    /// Extension values made available to request pipelines.
    state: http::Extensions,
}
46
47impl AppBuilder {
48    pub fn new() -> Self {
49        Self {
50            backends: HashMap::new(),
51            mounted: Vec::new(),
52            policies: HashMap::new(),
53            default_policy: Arc::new(Policy::new().router(NoRouteKey).balancer(RoundRobin::new())),
54            request_timeout: Duration::from_secs(30),
55            connect_timeout: Duration::from_secs(5),
56            pool_idle_timeout: Duration::from_secs(90),
57            map_body_limit: 2 * 1024 * 1024,
58            state: http::Extensions::new(),
59        }
60    }
61
62    pub fn backend<I, S>(mut self, service: &str, urls: I) -> Self
63    where
64        I: IntoIterator<Item = S>,
65        S: Into<String>,
66    {
67        self.backends.insert(
68            service.to_string(),
69            urls.into_iter().map(|s| s.into()).collect(),
70        );
71        self
72    }
73
74    pub fn policy(mut self, name: &str, policy: Policy) -> Self {
75        self.policies.insert(name.to_string(), Arc::new(policy));
76        self
77    }
78
79    pub fn default_policy(mut self, policy: Policy) -> Self {
80        self.default_policy = Arc::new(policy);
81        self
82    }
83
84    pub fn request_timeout(mut self, d: Duration) -> Self {
85        self.request_timeout = d;
86        self
87    }
88
89    pub fn connect_timeout(mut self, d: Duration) -> Self {
90        self.connect_timeout = d;
91        self
92    }
93
94    pub fn pool_idle_timeout(mut self, d: Duration) -> Self {
95        self.pool_idle_timeout = d;
96        self
97    }
98
99    pub fn map_body_limit(mut self, bytes: usize) -> Self {
100        self.map_body_limit = bytes;
101        self
102    }
103
104    pub fn state<T: Clone + Send + Sync + 'static>(mut self, val: T) -> Self {
105        self.state.insert(val);
106        self
107    }
108
109    pub fn mount(mut self, routes: Routes) -> Self {
110        self.mounted.push(routes);
111        self
112    }
113
114    pub fn build(self) -> Result<App, String> {
115        // HTTP client
116        let mut connector = HttpConnector::new();
117        connector.set_nodelay(true);
118        connector.set_connect_timeout(Some(self.connect_timeout));
119        connector.set_keepalive(Some(self.pool_idle_timeout));
120
121        let client = Client::builder(TokioExecutor::new())
122            .pool_timer(TokioTimer::new())
123            .pool_idle_timeout(self.pool_idle_timeout)
124            .build(connector);
125
126        // backend pools
127        let pools: HashMap<_, _> = self
128            .backends
129            .into_iter()
130            .map(|(svc, urls)| {
131                let bases = urls
132                    .iter()
133                    .map(|u| BaseUri::parse(u))
134                    .collect::<Result<Vec<_>, _>>()?;
135                Ok((svc, Arc::new(BackendPool::new(bases))))
136            })
137            .collect::<Result<_, String>>()?;
138
139        let inner = Arc::new(Inner {
140            client,
141            state: Arc::new(self.state),
142            request_timeout: self.request_timeout,
143            map_body_limit: self.map_body_limit,
144        });
145
146        // build router with state
147        let mut router = Router::new();
148
149        for svc_routes in self.mounted {
150            let pool = pools
151                .get(svc_routes.service)
152                .ok_or_else(|| {
153                    format!(
154                        "backend for service `{}` is not registered",
155                        svc_routes.service,
156                    )
157                })?
158                .clone();
159
160            router = mount_service(
161                router,
162                svc_routes,
163                &self.policies,
164                self.default_policy.clone(),
165                pool,
166            )?;
167        }
168
169        let router = router.with_state(inner);
170
171        Ok(App { router })
172    }
173}
174
175impl App {
176    pub fn builder() -> AppBuilder {
177        AppBuilder::new()
178    }
179}
180
/// Binds a TCP listener on `addr` and serves `app` until the server stops,
/// propagating any I/O error from binding or serving.
pub async fn run(addr: SocketAddr, app: App) -> std::io::Result<()> {
    let listener = tokio::net::TcpListener::bind(addr).await?;
    // axum::serve is intentionally simple, which suits this internal wrapper.
    axum::serve(listener, app.router).await
}
186
187fn mount_service(
188    mut router: Router<Arc<Inner>>,
189    routes: Routes,
190    policies: &HashMap<String, Arc<Policy>>,
191    default_policy: Arc<Policy>,
192    pool: Arc<BackendPool>,
193) -> Result<Router<Arc<Inner>>, String> {
194    // проверим что backend зарегистрирован
195    // (в минимальной версии лучше фейлиться сразу)
196    // В будущем тут будет ServiceRegistry/PolicyRegistry.
197    //
198    // NOTE: State<Inner> у нас хранит HashMap<String, BackendPool>.
199    // Мы проверяем наличие ключа на запросе в handler'е тоже, но ранняя проверка приятнее.
200    // Т.к. mounted routes имеют &'static str service, мы делаем check на build-time.
201    //
202    // Здесь нет доступа к Inner.backends (он внутри Arc), поэтому check сделаем в handler'е.
203    // (Можно перестроить на builder-time, но это минимальная версия.)
204
205    for rd in routes.routes {
206        let full_path = join(routes.prefix, rd.path);
207
208        let policy = resolve_policy(routes.policy, rd.policy, policies, &default_policy)?;
209
210        let meta = Arc::new(RouteMeta {
211            service: routes.service,
212            route_path: rd.path,
213            prefix: routes.prefix,
214            rewrite: match rd.rewrite {
215                RewriteSpec::StripPrefix => Rewrite::StripPrefix,
216                RewriteSpec::Static(to) => Rewrite::Static(FixedRewrite::new(to)),
217                RewriteSpec::Template(tpl) => Rewrite::Template(tpl),
218            },
219            pool: Arc::clone(&pool),
220            policy,
221            pipeline: rd.pipeline,
222        });
223
224        let method_router = method_router(rd.method).layer(Extension(meta));
225
226        router = router.route(&full_path, method_router);
227    }
228
229    Ok(router)
230}
231
232fn resolve_policy(
233    service_policy: Option<&'static str>,
234    route_policy: Option<&'static str>,
235    registry: &HashMap<String, Arc<Policy>>,
236    default_policy: &Arc<Policy>,
237) -> Result<Arc<Policy>, String> {
238    let effective = route_policy.or(service_policy);
239
240    match effective {
241        Some(name) => registry
242            .get(name)
243            .cloned()
244            .ok_or_else(|| format!("policy `{name}` is not registered")),
245        None => Ok(default_policy.clone()),
246    }
247}
248
/// Joins a mount prefix and a route path into a single axum path.
///
/// Trailing slashes on the prefix are dropped and a missing leading slash
/// on the path is supplied, so exactly one `/` separates the two parts:
/// `"/sales"` + `"/ping"` -> `"/sales/ping"`; `"/sales"` + `"/"` -> `"/sales/"`.
fn join(prefix: &str, path: &str) -> String {
    let head = prefix.trim_end_matches('/');
    if path.starts_with('/') {
        format!("{head}{path}")
    } else {
        format!("{head}/{path}")
    }
}
266
267fn method_router(method: Method) -> routing::MethodRouter<Arc<Inner>> {
268    match method {
269        Method::Get => routing::get(proxy_handler),
270        Method::Post => routing::post(proxy_handler),
271        Method::Put => routing::put(proxy_handler),
272        Method::Delete => routing::delete(proxy_handler),
273        Method::Patch => routing::patch(proxy_handler),
274        Method::Head => routing::head(proxy_handler),
275        Method::Options => routing::options(proxy_handler),
276    }
277}
278
/// Axum handler that proxies a single request to an upstream backend.
///
/// Flow: run the optional route pipeline (request hooks + body mapping),
/// ask the policy's router for candidate backends, let the balancer pick
/// one, forward the request under `request_timeout`, and report the
/// outcome (status or error, plus elapsed time) back to the balancer.
async fn proxy_handler(
    State(inner): State<Arc<Inner>>,
    Extension(meta): Extension<Arc<RouteMeta>>,
    req: AxumRequest,
) -> axum::response::Response {
    let pool = &meta.pool;
    let (mut parts, body) = req.into_parts();

    // Pipeline: before hooks + body validation/map in a single pass.
    // A pipeline error short-circuits straight into an HTTP response.
    let body = if let Some(pipeline) = meta.pipeline {
        let ctx = PartsCtx::new(meta.service, meta.route_path, &mut parts);
        let scope = RequestScope::with_shared(Arc::clone(&inner.state), body, inner.map_body_limit);

        match pipeline(ctx, scope).await {
            Ok(body) => body,
            Err(err) => return err.into_response(),
        }
    } else {
        body
    };

    // Routing: the policy's router narrows the pool to candidate backends
    // and may supply an affinity hint for the balancer.
    let route_ctx = RouteCtx {
        service: meta.service,
        prefix: meta.prefix,
        route_path: meta.route_path,
        method: &parts.method,
        uri: &parts.uri,
        headers: &parts.headers,
    };
    let routing = meta.policy.router.route(&route_ctx, pool);

    // Balancer: pick one backend index among the candidates. Either guard
    // failing yields a 502 without touching any backend.
    let balance_ctx = BalanceCtx {
        service: meta.service,
        affinity: routing.affinity.as_ref(),
        pool,
        candidates: routing.candidates,
    };
    let Some(backend_index) = meta.policy.balancer.pick(&balance_ctx) else {
        return bad_gateway("no backends selected by balancer");
    };
    let Some(backend) = pool.get(backend_index) else {
        return bad_gateway("balancer returned invalid backend index");
    };

    // Make request. `on_start` is always paired with exactly one
    // `on_result` below, so the balancer can track in-flight requests.
    meta.policy.balancer.on_start(&StartEvent {
        service: meta.service,
        backend_index,
    });

    let started_at = Instant::now();

    // The whole upstream exchange is bounded by `request_timeout`; timer
    // expiry is folded into `ProxyErrorKind::Timeout` so both paths share
    // the error handling below.
    let result = tokio::time::timeout(
        inner.request_timeout,
        proxy_request(backend, &inner.client, &meta, parts, body),
    )
    .await
    .unwrap_or_else(|_| Err(ProxyErrorKind::Timeout));

    match result {
        Ok(response) => {
            // Elapsed time reported as `head_latency` — measured here,
            // before the response body is streamed to the client.
            let elapsed = started_at.elapsed();

            meta.policy.balancer.on_result(&ResultEvent {
                service: meta.service,
                backend_index,
                status: Some(response.status()),
                error: None,
                head_latency: elapsed,
            });

            response
        }
        Err(error) => {
            let elapsed = started_at.elapsed();

            meta.policy.balancer.on_result(&ResultEvent {
                service: meta.service,
                backend_index,
                status: None,
                error: Some(error),
                head_latency: elapsed,
            });

            // Map internal proxy errors onto gateway HTTP responses.
            match error {
                ProxyErrorKind::NoBackends => bad_gateway("no backends"),
                ProxyErrorKind::InvalidUpstreamUri => bad_gateway("bad upstream uri"),
                ProxyErrorKind::UpstreamRequest => bad_gateway("upstream request failed"),
                ProxyErrorKind::Timeout => gateway_timeout("upstream request timed out"),
            }
        }
    }
}