1use crate::lb::LoadBalancerRegistry;
2use crate::BoxError;
3use async_trait::async_trait;
4use http::Extensions;
5use reqwest::{Request, Response, Url};
6use reqwest_middleware::{Middleware, Next};
7use std::fmt::Debug;
8use thiserror::Error;
9use tracing::debug;
10
/// Reports whether `schema` begins with `lb`, compared ASCII case-insensitively.
///
/// Only the first two bytes are inspected; whatever follows them is ignored,
/// so this is a prefix check rather than an exact match. Inputs shorter than
/// two bytes, or whose leading bytes are not ASCII `l`/`L` then `b`/`B`,
/// yield `false`.
fn is_lb_schema(schema: &str) -> bool {
    matches!(schema.as_bytes(), [b'l' | b'L', b'b' | b'B', ..])
}
17
18pub struct LoadBalancerMiddleware<I, E> {
19 registry: LoadBalancerRegistry<I, E>,
20}
21
22impl<I, E> LoadBalancerMiddleware<I, E> {
23 pub fn new(registry: LoadBalancerRegistry<I, E>) -> Self {
24 Self { registry }
25 }
26}
27
28#[async_trait]
29impl<I, E, IE> Middleware for LoadBalancerMiddleware<I, E>
30where
31 I: TryInto<Url, Error = IE> + 'static,
32 IE: Into<BoxError> + 'static,
33 E: Into<BoxError> + 'static,
34{
35 async fn handle(
36 &self,
37 mut request: Request,
38 extensions: &mut Extensions,
39 next: Next<'_>,
40 ) -> reqwest_middleware::Result<Response> {
41 let schema = request.url().scheme();
42 if is_lb_schema(schema) {
43 let host = request.url().host_str().ok_or(Error::MissHost)?;
44 let load_balancer = self
45 .registry
46 .find(host)
47 .ok_or(Error::NotFoundLoadBalancer)?;
48 let item = load_balancer
49 .choose(extensions)
50 .await
51 .map_err(|e| Error::Customize(e.into()))?
52 .ok_or(Error::NotFoundElement)?;
53 let source = request.url();
54 let mut target = item.try_into().map_err(|e| Error::InvalidUrl(e.into()))?;
55 reconstruct(source, &mut target);
56 debug!("reconstruct new url: {}", target.as_str());
57 *request.url_mut() = target;
58 }
59 next.run(request, extensions).await
60 }
61}
62
63fn reconstruct(source: &Url, target: &mut Url) {
64 target.set_path(source.path());
65 target.set_query(source.query());
66 target.set_fragment(source.fragment());
67}
68
69#[derive(Debug, Error)]
70pub enum Error {
71 #[error("Invalid url: {0}")]
72 InvalidUrl(BoxError),
73
74 #[error("Registry not found load balancer")]
75 NotFoundLoadBalancer,
76
77 #[error("Load balancer not found element")]
78 NotFoundElement,
79
80 #[error("Request miss host")]
81 MissHost,
82
83 #[error("{0}")]
84 Customize(BoxError),
85}
86
87impl Error {
88 pub fn customize<E: Into<BoxError>>(error: E) -> Self {
89 Self::Customize(error.into())
90 }
91}
92
93impl From<Error> for reqwest_middleware::Error {
94 fn from(value: Error) -> Self {
95 Self::middleware(value)
96 }
97}