dhomer 0.1.0

A simple, easy-to-use proxy server based on Pingora
use crate::config::router_pattern::RouterPattern;
use async_trait::async_trait;
use http::uri::PathAndQuery;
use http::Uri;
use pingora::http::RequestHeader;
use pingora::protocols::l4::socket::SocketAddr;
use pingora::{
    lb::LoadBalancer,
    prelude::{HttpPeer, RoundRobin},
    proxy::{ProxyHttp, Session},
    ErrorType::Custom,
    OkOrErr,
};
use snafu::Snafu;
use std::str::FromStr;
use std::sync::Arc;

/// Catch-all error type for ad-hoc failures, built on snafu's `whatever` support.
///
/// The optional `source` keeps the underlying cause. It is stored as
/// `Box<dyn Error + Send + Sync>`, the same type the `from` conversion accepts. This keeps the
/// `Send + Sync` bounds, so the whole error stays `Send + Sync` and can cross threads or be held
/// across `.await` points.
#[derive(Debug, Snafu)]
pub enum Error {
    /// Free-form error carrying a human-readable message and an optional boxed cause.
    #[snafu(whatever, display("{message}"))]
    Whatever {
        /// Message rendered by `Display`.
        message: String,
        /// Underlying cause, if one was supplied.
        #[snafu(source(from(Box<dyn std::error::Error + Send + Sync>, Some)))]
        source: Option<Box<dyn std::error::Error + Send + Sync>>,
    },
}

/// HTTP proxy that routes each request to a load-balanced backend pool by path pattern.
///
/// Routes are tried in insertion order and the first matching pattern wins.
pub struct Proxy {
    /// `(pattern, pool)` routes, checked in order.
    backends: Arc<Vec<(RouterPattern, Arc<LoadBalancer<RoundRobin>>)>>,
}

impl Proxy {
    /// Builds a proxy from `(pattern, pool)` routes.
    ///
    /// Accepts anything iterable: a `Vec`, an array, or an existing iterator. The item order is
    /// the order in which patterns are tried.
    pub fn new(
        lbs: impl IntoIterator<Item = (RouterPattern, Arc<LoadBalancer<RoundRobin>>)>,
    ) -> Self {
        Self {
            backends: Arc::new(lbs.into_iter().collect()),
        }
    }

    /// Returns the pool of the first route whose pattern matches the request path.
    ///
    /// Only the path component of the downstream URI is considered; the query string is ignored.
    /// Returns `None` when no route matches.
    pub fn determine_load_balancer(
        &self,
        session: &Session,
    ) -> Option<Arc<LoadBalancer<RoundRobin>>> {
        let path = session.req_header().uri.path();
        self.backends
            .iter()
            .find(|(pattern, _)| pattern.is_match(path))
            .map(|(_, lb)| Arc::clone(lb))
    }
}

#[async_trait]
impl ProxyHttp for Proxy {
    type CTX = ();

    fn new_ctx(&self) -> Self::CTX {
        
    }

    async fn upstream_peer(&self, session: &mut Session, _ctx: &mut Self::CTX) -> pingora::Result<Box<HttpPeer>> {
        let req = session.req_header_mut();
        let parts = req.uri.clone().into_parts();
        let lb = if let Some(pq) = parts.path_and_query {
            let (rest, lb) = self.backends
                .iter()
                .find_map(|(p, lb)| p.rest_after_match(pq.path()).map(|r| (r, lb)))
                .or_err(Custom("NoTarget"),
                    "Failed to get target upstream for session",
                )?;

            let new_pq_str = pq.query()
                .map_or_else(|| rest.clone(), |q| format!("{}?{}", rest, q));
            let new_pq = PathAndQuery::from_str(&new_pq_str).unwrap();

            let mut uri_builder = Uri::builder();
            if let Some(scheme) = parts.scheme {
                uri_builder = uri_builder.scheme(scheme);
            }
            if let Some(authority) = parts.authority {
                uri_builder = uri_builder.authority(authority);
            }
            uri_builder = uri_builder.path_and_query(new_pq);
            let new_uri = uri_builder.build().unwrap();

            req.set_uri(new_uri);

            lb
        } else {
            let (_pattern, lb) =self.backends
                .iter()
                .find(|(p, lb)| p.is_match(""))
                .or_err(Custom("NoTarget"),
                        "Failed to get target upstream for session",
                )?;

            lb
        };

        let backend = lb.select(b"", 256).or_err(
            Custom("NoAvailable"),
            "Failed to get available backend for session",
        )?;

        let peer = Box::new(HttpPeer::new(backend, false, "".to_string()));
        Ok(peer)
    }

    async fn upstream_request_filter(&self, session: &mut Session, upstream_request: &mut RequestHeader, _ctx: &mut Self::CTX) -> pingora::Result<()>
    where
        Self::CTX: Send + Sync,
    {
        if let Some(remote_addr) = session.client_addr() {
            #[allow(irrefutable_let_patterns)]
            if let SocketAddr::Inet(addr) = remote_addr {
                let client_ip = addr.ip().to_string();

                // 内层代理:仅在X-Real-Ip不存在时设置(避免覆盖外层代理的值)
                if upstream_request.headers.get("X-Real-Ip").is_none() {
                    upstream_request.insert_header("X-Real-Ip", &client_ip)?;
                }

                // X-Forwarded-For:始终追加当前代理感知到的IP(维持完整链路)
                let x_forwarded_for = upstream_request.headers
                    .get("X-Forwarded-For")
                    .and_then(|x| x.to_str().ok())
                    .map_or(client_ip.clone(), |v| format!("{}, {}", v, client_ip));
                upstream_request.insert_header("X-Forwarded-For", x_forwarded_for)?;
            }
        }

        Ok(())
    }
}