Skip to main content

apigate_core/
app.rs

1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use axum::body::Body;
7use axum::extract::{Request as AxumRequest, State};
8use axum::response::IntoResponse;
9use axum::routing;
10use axum::{Extension, Router};
11
12use hyper_util::client::legacy::Client;
13use hyper_util::rt::TokioExecutor;
14
15use crate::backend::{BackendPool, BaseUri};
16use crate::balancing::{BalanceCtx, ProxyErrorKind, ResultEvent, RoundRobin, StartEvent};
17use crate::policy::Policy;
18use crate::proxy::{bad_gateway, proxy_request};
19use crate::route::{FixedRewrite, Rewrite, RewriteSpec, RouteMeta};
20use crate::routing::{NoRouteKey, RouteCtx};
21use crate::{Method, PartsCtx, Routes};
22
23struct Inner {
24    backends: HashMap<String, BackendPool>,
25    client: Client<hyper_util::client::legacy::connect::HttpConnector, Body>,
26    map_body_limit: usize,
27    _default_timeout: Duration, // пока не используется
28}
29
30pub struct App {
31    router: Router,
32}
33
34pub struct AppBuilder {
35    backends: HashMap<String, Vec<String>>,
36    mounted: Vec<Routes>,
37    policies: HashMap<String, Arc<Policy>>,
38    default_policy: Arc<Policy>,
39    default_timeout: Duration,
40    map_body_limit: usize,
41}
42
43impl AppBuilder {
44    pub fn new() -> Self {
45        Self {
46            backends: HashMap::new(),
47            mounted: Vec::new(),
48            policies: HashMap::new(),
49            default_policy: Arc::new(Policy::new().router(NoRouteKey).balancer(RoundRobin::new())),
50            default_timeout: Duration::from_secs(30),
51            map_body_limit: 2 * 1024 * 1024,
52        }
53    }
54
55    pub fn backend<I, S>(mut self, service: &str, urls: I) -> Self
56    where
57        I: IntoIterator<Item = S>,
58        S: Into<String>,
59    {
60        self.backends.insert(
61            service.to_string(),
62            urls.into_iter().map(|s| s.into()).collect(),
63        );
64        self
65    }
66
67    pub fn policy(mut self, name: &str, policy: Policy) -> Self {
68        self.policies.insert(name.to_string(), Arc::new(policy));
69        self
70    }
71
72    pub fn default_policy(mut self, policy: Policy) -> Self {
73        self.default_policy = Arc::new(policy);
74        self
75    }
76
77    pub fn default_timeout(mut self, d: Duration) -> Self {
78        self.default_timeout = d;
79        self
80    }
81
82    pub fn map_body_limit(mut self, bytes: usize) -> Self {
83        self.map_body_limit = bytes;
84        self
85    }
86
87    pub fn mount(mut self, routes: Routes) -> Self {
88        self.mounted.push(routes);
89        self
90    }
91
92    pub fn build(self) -> Result<App, String> {
93        // client (pooling внутри hyper-util)
94        let client = Client::builder(TokioExecutor::new()).build_http();
95
96        // backend pools
97        let mut pools = HashMap::with_capacity(self.backends.len());
98        for (svc, urls) in self.backends {
99            let bases = urls
100                .iter()
101                .map(|u| BaseUri::parse(u))
102                .collect::<Result<Vec<_>, _>>()?;
103            pools.insert(svc, BackendPool::new(bases));
104        }
105
106        let inner = Arc::new(Inner {
107            backends: pools,
108            client,
109            _default_timeout: self.default_timeout,
110            map_body_limit: self.map_body_limit,
111        });
112
113        // build router with state
114        let mut router = Router::new();
115
116        for svc_routes in self.mounted {
117            if !inner.backends.contains_key(svc_routes.service) {
118                return Err(format!(
119                    "backend for service `{}` is not registered",
120                    svc_routes.service
121                ));
122            }
123
124            router = mount_service(
125                router,
126                svc_routes,
127                &self.policies,
128                self.default_policy.clone(),
129            )?;
130        }
131
132        let router = router.with_state(inner);
133
134        Ok(App { router })
135    }
136}
137
138impl App {
139    pub fn builder() -> AppBuilder {
140        AppBuilder::new()
141    }
142}
143
144pub async fn run(addr: SocketAddr, app: App) -> std::io::Result<()> {
145    let listener = tokio::net::TcpListener::bind(addr).await?;
146    // axum::serve intentionally simple (и это нам подходит как внутренняя обертка)
147    axum::serve(listener, app.router).await
148}
149
150fn mount_service(
151    mut router: Router<Arc<Inner>>,
152    routes: Routes,
153    policies: &HashMap<String, Arc<Policy>>,
154    default_policy: Arc<Policy>,
155) -> Result<Router<Arc<Inner>>, String> {
156    // проверим что backend зарегистрирован
157    // (в минимальной версии лучше фейлиться сразу)
158    // В будущем тут будет ServiceRegistry/PolicyRegistry.
159    //
160    // NOTE: State<Inner> у нас хранит HashMap<String, BackendPool>.
161    // Мы проверяем наличие ключа на запросе в handler'е тоже, но ранняя проверка приятнее.
162    // Т.к. mounted routes имеют &'static str service, мы делаем check на build-time.
163    //
164    // Здесь нет доступа к Inner.backends (он внутри Arc), поэтому check сделаем в handler'е.
165    // (Можно перестроить на builder-time, но это минимальная версия.)
166
167    for rd in routes.routes {
168        let full_path = join(routes.prefix, rd.path);
169
170        let policy = resolve_policy(routes.policy, rd.policy, policies, &default_policy)?;
171
172        let meta = RouteMeta {
173            service: routes.service,
174            route_path: rd.path,
175            prefix: routes.prefix,
176            rewrite: match rd.rewrite {
177                RewriteSpec::StripPrefix => Rewrite::StripPrefix,
178                RewriteSpec::Static(to) => Rewrite::Static(FixedRewrite::new(to)),
179                RewriteSpec::Template(tpl) => Rewrite::Template(tpl),
180            },
181            policy,
182            before: rd.before,
183            map: rd.map,
184        };
185
186        let method_router = method_router(rd.method).layer(Extension(meta));
187
188        router = router.route(&full_path, method_router);
189    }
190
191    Ok(router)
192}
193
194fn resolve_policy(
195    service_policy: Option<&'static str>,
196    route_policy: Option<&'static str>,
197    registry: &HashMap<String, Arc<Policy>>,
198    default_policy: &Arc<Policy>,
199) -> Result<Arc<Policy>, String> {
200    let effective = route_policy.or(service_policy);
201
202    match effective {
203        Some(name) => registry
204            .get(name)
205            .cloned()
206            .ok_or_else(|| format!("policy `{name}` is not registered")),
207        None => Ok(default_policy.clone()),
208    }
209}
210
/// Joins a route-set prefix and a route path with exactly one `/` at the seam
/// when `path` has at most one leading slash.
///
/// All trailing slashes of `prefix` are dropped, and a `/` is inserted when
/// `path` does not already start with one. `path` is otherwise copied as is.
///
/// * `prefix = "/sales"`, `path = "/ping"` gives `"/sales/ping"`
/// * `prefix = "/sales"`, `path = "/"` gives `"/sales/"`
/// * `prefix = "/sales/"`, `path = "ping"` gives `"/sales/ping"`
fn join(prefix: &str, path: &str) -> String {
    // `trim_end_matches` is a no-op when there is no trailing slash, so no guard is needed.
    let base = prefix.trim_end_matches('/');
    let needs_separator = !path.starts_with('/');

    // Reserve room for the separator too, so the string never has to grow.
    let mut joined =
        String::with_capacity(base.len() + usize::from(needs_separator) + path.len());
    joined.push_str(base);
    if needs_separator {
        joined.push('/');
    }
    joined.push_str(path);
    joined
}
228
229fn method_router(method: Method) -> routing::MethodRouter<Arc<Inner>> {
230    match method {
231        Method::Get => routing::get(proxy_handler),
232        Method::Post => routing::post(proxy_handler),
233        Method::Put => routing::put(proxy_handler),
234        Method::Delete => routing::delete(proxy_handler),
235        Method::Patch => routing::patch(proxy_handler),
236        Method::Head => routing::head(proxy_handler),
237        Method::Options => routing::options(proxy_handler),
238    }
239}
240
241async fn proxy_handler(
242    State(inner): State<Arc<Inner>>,
243    Extension(meta): Extension<RouteMeta>,
244    mut req: AxumRequest,
245) -> axum::response::Response {
246    let Some(pool) = inner.backends.get(meta.service) else {
247        return bad_gateway(format!("unknown backend `{}`", meta.service));
248    };
249
250    // Before hook
251    if let Some(before) = meta.before {
252        let (mut parts, body) = req.into_parts();
253        let ctx = PartsCtx::new(meta.service, meta.route_path, &mut parts);
254
255        if let Err(err) = before(ctx).await {
256            return err.into_response();
257        }
258
259        req = http::Request::from_parts(parts, body);
260    }
261
262    // Map
263    if let Some(map) = meta.map {
264        let (mut parts, body) = req.into_parts();
265        let ctx = PartsCtx::new(meta.service, meta.route_path, &mut parts);
266
267        let body = match map(ctx, body, inner.map_body_limit).await {
268            Ok(body) => body,
269            Err(err) => return err.into_response(),
270        };
271
272        req = http::Request::from_parts(parts, body);
273    }
274
275    // Routing
276    let route_ctx = RouteCtx {
277        service: meta.service,
278        route_path: meta.route_path,
279        method: req.method(),
280        uri: req.uri(),
281        headers: req.headers(),
282    };
283    let routing = meta.policy.router.route(&route_ctx, pool);
284
285    // Balancer
286    let balance_ctx = BalanceCtx {
287        service: meta.service,
288        affinity: routing.affinity.as_ref(),
289        pool,
290        candidates: routing.candidates,
291    };
292    let Some(backend_index) = meta.policy.balancer.pick(&balance_ctx) else {
293        return bad_gateway("no backends selected by balancer");
294    };
295    let Some(backend) = pool.get(backend_index) else {
296        return bad_gateway("balancer returned invalid backend index");
297    };
298
299    // Make request
300    meta.policy.balancer.on_start(&StartEvent {
301        service: meta.service,
302        backend_index,
303    });
304
305    let started_at = Instant::now();
306
307    match proxy_request(backend, &inner.client, &meta, req).await {
308        Ok(response) => {
309            let elapsed = started_at.elapsed();
310
311            meta.policy.balancer.on_result(&ResultEvent {
312                service: meta.service,
313                backend_index,
314                status: Some(response.status()),
315                error: None,
316                head_latency: elapsed,
317            });
318
319            response
320        }
321        Err(error) => {
322            let elapsed = started_at.elapsed();
323
324            meta.policy.balancer.on_result(&ResultEvent {
325                service: meta.service,
326                backend_index,
327                status: None,
328                error: Some(error),
329                head_latency: elapsed,
330            });
331
332            match error {
333                ProxyErrorKind::NoBackends => bad_gateway("no backends"),
334                ProxyErrorKind::InvalidUpstreamUri => bad_gateway("bad upstream uri"),
335                ProxyErrorKind::UpstreamRequest => bad_gateway("upstream request failed"),
336            }
337        }
338    }
339}