// reqwest_lb/middleware.rs

1use crate::lb::LoadBalancerRegistry;
2use crate::BoxError;
3use async_trait::async_trait;
4use http::Extensions;
5use reqwest::{Request, Response, Url};
6use reqwest_middleware::{Middleware, Next};
7use std::fmt::Debug;
8use thiserror::Error;
9use tracing::debug;
10
/// Returns `true` when `schema` is exactly the load-balancer URL scheme `lb`.
///
/// The comparison is ASCII case-insensitive, so `lb`, `LB`, `Lb` and `lB` all
/// match. The whole scheme must match: a scheme that merely *starts* with
/// `lb` (for example `lbs` or `lbry`) is not treated as a load-balancer
/// scheme, so unrelated URLs are never rewritten by the middleware.
fn is_lb_schema(schema: &str) -> bool {
    schema.eq_ignore_ascii_case("lb")
}
17
18pub struct LoadBalancerMiddleware<I, E> {
19    registry: LoadBalancerRegistry<I, E>,
20}
21
22impl<I, E> LoadBalancerMiddleware<I, E> {
23    pub fn new(registry: LoadBalancerRegistry<I, E>) -> Self {
24        Self { registry }
25    }
26}
27
28#[async_trait]
29impl<I, E, IE> Middleware for LoadBalancerMiddleware<I, E>
30where
31    I: TryInto<Url, Error = IE> + 'static,
32    IE: Into<BoxError> + 'static,
33    E: Into<BoxError> + 'static,
34{
35    async fn handle(
36        &self,
37        mut request: Request,
38        extensions: &mut Extensions,
39        next: Next<'_>,
40    ) -> reqwest_middleware::Result<Response> {
41        let schema = request.url().scheme();
42        if is_lb_schema(schema) {
43            let host = request.url().host_str().ok_or(Error::MissHost)?;
44            let load_balancer = self
45                .registry
46                .find(host)
47                .ok_or(Error::NotFoundLoadBalancer)?;
48            let item = load_balancer
49                .choose(extensions)
50                .await
51                .map_err(|e| Error::Customize(e.into()))?
52                .ok_or(Error::NotFoundElement)?;
53            let source = request.url();
54            let mut target = item.try_into().map_err(|e| Error::InvalidUrl(e.into()))?;
55            reconstruct(source, &mut target);
56            debug!("reconstruct new url: {}", target.as_str());
57            *request.url_mut() = target;
58        }
59        next.run(request, extensions).await
60    }
61}
62
63fn reconstruct(source: &Url, target: &mut Url) {
64    target.set_path(source.path());
65    target.set_query(source.query());
66    target.set_fragment(source.fragment());
67}
68
69#[derive(Debug, Error)]
70pub enum Error {
71    #[error("Invalid url: {0}")]
72    InvalidUrl(BoxError),
73
74    #[error("Registry not found load balancer")]
75    NotFoundLoadBalancer,
76
77    #[error("Load balancer not found element")]
78    NotFoundElement,
79
80    #[error("Request miss host")]
81    MissHost,
82
83    #[error("{0}")]
84    Customize(BoxError),
85}
86
87impl Error {
88    pub fn customize<E: Into<BoxError>>(error: E) -> Self {
89        Self::Customize(error.into())
90    }
91}
92
93impl From<Error> for reqwest_middleware::Error {
94    fn from(value: Error) -> Self {
95        Self::middleware(value)
96    }
97}