socks-hub 0.2.6

Convert http proxy or socks5 proxy to socks5 proxy.
Documentation
use crate::{BoxError, CONNECT_TIMEOUT, Config, TokioIo, std_io_error_other};
use bytes::Bytes;
use http_body_util::{BodyExt, combinators::BoxBody};
use hyper::{
    Method, Request, Response,
    header::{AUTHORIZATION, HeaderName, HeaderValue, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION},
    service::service_fn,
    upgrade::Upgraded,
};
use socks5_impl::protocol::{Address, ProxyParameters, UserKey};
use std::net::SocketAddr;
use tokio::net::TcpListener;

#[cfg(feature = "acl")]
use crate::acl::TargetDecision;

/// Port used when a request URI does not carry an explicit port.
const HTTP_DEFAULT_PORT: u16 = 80;

/// Process-wide access-control state, filled in at most once by `main_entry`.
///
/// The inner `Option` is `None` when no ACL file is configured or when loading
/// the configured file failed.
#[cfg(feature = "acl")]
static ACL_CENTER: std::sync::OnceLock<Option<crate::acl::AccessControl>> = std::sync::OnceLock::new();

pub async fn main_entry<F>(config: &Config, cancel_token: tokio_util::sync::CancellationToken, callback: Option<F>) -> Result<(), BoxError>
where
    F: FnOnce(SocketAddr) + Send + Sync + 'static,
{
    #[cfg(feature = "acl")]
    ACL_CENTER.get_or_init(|| {
        config
            .acl_file
            .as_ref()
            .and_then(|acl_file| match crate::acl::AccessControl::load_from_file(acl_file) {
                Ok(ac) => Some(ac),
                Err(e) => {
                    log::warn!("Could not init ACL: {e}");
                    None
                }
            })
    });

    let listen_addr: SocketAddr = config.listen_proxy_role.addr.clone().try_into()?;

    let listener = TcpListener::bind(listen_addr).await?;

    if let Some(callback) = callback {
        callback(listener.local_addr()?);
    } else {
        log::info!("Listening on {}", config.listen_proxy_role);
    }

    let config = std::sync::Arc::new(config.clone());

    loop {
        let config = config.clone();
        tokio::select! {
            _ = cancel_token.cancelled() => {
                log::info!("quit signal received");
                break;
            }
            result = listener.accept() => {
                let (stream, incoming) = result?;
                tokio::task::spawn(async move {
                    if let Err(err) = build_http_service(stream, config).await {
                        log::error!("http service on incoming {incoming} error: {err}");
                    }
                });
            }
        }
    }
    Ok(())
}

/// Serves one accepted TCP connection as an HTTP/1 connection with upgrade support.
///
/// Every request on the connection is handed to `proxy` together with a clone
/// of the shared configuration. Header case is preserved and title-cased.
///
/// # Errors
///
/// Returns the error reported by hyper while driving the connection.
async fn build_http_service(stream: tokio::net::TcpStream, config: std::sync::Arc<Config>) -> Result<(), BoxError> {
    let service = service_fn(|req: Request<hyper::body::Incoming>| {
        let config = config.clone();
        async move { proxy(req, config).await }
    });

    let mut builder = hyper::server::conn::http1::Builder::new();
    builder.preserve_header_case(true).title_case_headers(true);

    // `with_upgrades` is needed so that CONNECT requests can be turned into tunnels.
    let connection = builder.serve_connection(TokioIo::new(stream), service).with_upgrades();
    connection.await?;
    Ok(())
}

/// Handles a single request received by the HTTP proxy front end.
///
/// Flow:
/// 1. Checks the client's proxy credentials against the configured listen
///    credentials. On failure a `407 Proxy Authentication Required` response
///    with a `Basic` challenge is returned.
/// 2. Removes the header that carried the proxy credentials so that it is not
///    forwarded.
/// 3. For `CONNECT`, replies with an empty response and spawns a task that
///    waits for the connection upgrade and then runs `tunnel`.
/// 4. For any other method, forwards the request to the destination taken from
///    the request URI, through the SOCKS5 server (or directly when the ACL
///    decides to bypass it).
///
/// With the `acl` feature enabled, a destination the ACL blocks gets a
/// `403 Forbidden` response.
///
/// # Errors
///
/// Returns an I/O error when connecting to the destination / SOCKS5 server or
/// forwarding the request fails.
async fn proxy(
    mut req: Request<hyper::body::Incoming>,
    config: std::sync::Arc<Config>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, std::io::Error> {
    //
    // https://github.com/hyperium/hyper/blob/90eb95f62a32981cb662b0f750027231d8a2586b/examples/http_proxy.rs#L51
    //
    log::trace!("req: {req:?}");

    let server = config.remote_server.clone();
    let credentials = config.get_listen_credentials();
    let middle_server = config.middle_server.clone();

    // Locates the header that carries the client's proxy credentials.
    //
    // `Proxy-Authorization` is the header defined for authenticating to a proxy,
    // so it takes priority. `Authorization` is only used as a fallback; checking
    // it first would validate (and strip) credentials meant for the origin
    // server whenever a client sends both headers.
    fn get_proxy_authorization(req: &Request<hyper::body::Incoming>) -> (Option<HeaderName>, Option<&HeaderValue>) {
        if let Some(header) = req.headers().get(PROXY_AUTHORIZATION) {
            (Some(PROXY_AUTHORIZATION), Some(header))
        } else if let Some(header) = req.headers().get(AUTHORIZATION) {
            (Some(AUTHORIZATION), Some(header))
        } else {
            (None, None)
        }
    }

    let (auth_header, auth_value) = get_proxy_authorization(&req);
    // Some clients may omit proxy auth on the first CONNECT request and retry after a 407 challenge.
    if !is_proxy_authorized(&credentials, auth_value) {
        log::warn!("authorization fail");
        let mut resp = Response::new(empty());
        *resp.status_mut() = hyper::StatusCode::PROXY_AUTHENTICATION_REQUIRED;
        resp.headers_mut()
            .insert(PROXY_AUTHENTICATE, HeaderValue::from_static("Basic realm=\"socks-hub\""));
        return Ok(resp);
    }
    // The credentials were consumed by this proxy; do not pass them on.
    if let Some(auth_header) = auth_header {
        let _ = req.headers_mut().remove(auth_header);
    }

    if Method::CONNECT == req.method() {
        if let Some(host) = req.uri().host() {
            let port = req.uri().port_u16().unwrap_or(HTTP_DEFAULT_PORT);
            let s5addr = Address::from((host, port));

            // Reject blocked targets before answering the CONNECT; the
            // proxy/bypass choice itself is made later inside `tunnel`.
            #[cfg(feature = "acl")]
            {
                if let Some(Some(acl)) = ACL_CENTER.get() {
                    match acl.decide_target(&s5addr).await {
                        TargetDecision::Proxy | TargetDecision::Bypass => {}
                        TargetDecision::Block => {
                            let mut resp = Response::new(full("blocked by ACL"));
                            *resp.status_mut() = hyper::http::StatusCode::FORBIDDEN;
                            return Ok(resp);
                        }
                    }
                }
            }

            // The upgrade only completes after the response below has been
            // returned, so the tunnel has to run on its own task.
            tokio::task::spawn(async move {
                match hyper::upgrade::on(req).await {
                    Ok(upgraded) => {
                        if let Err(e) = tunnel(upgraded, s5addr, server, middle_server).await {
                            log::error!("server io error: {e}");
                        };
                    }
                    Err(e) => log::error!("upgrade error: {e}"),
                }
            });
            Ok(Response::new(empty()))
        } else {
            log::error!("CONNECT host is not socket addr: {:?}", req.uri());
            let mut resp = Response::new(full("CONNECT must be to a socket address"));
            *resp.status_mut() = hyper::http::StatusCode::BAD_REQUEST;
            Ok(resp)
        }
    } else {
        // A URI without a host yields an empty host name here.
        let host = req.uri().host().unwrap_or_default();
        let port = req.uri().port_u16().unwrap_or(HTTP_DEFAULT_PORT);
        let s5addr = Address::from((host, port));

        log::debug!("destination address {s5addr}");

        #[cfg(feature = "acl")]
        {
            let mut must_proxied = true;
            if let Some(Some(acl)) = ACL_CENTER.get() {
                match acl.decide_target(&s5addr).await {
                    TargetDecision::Proxy => must_proxied = true,
                    TargetDecision::Bypass => must_proxied = false,
                    TargetDecision::Block => {
                        let mut resp = Response::new(full("blocked by ACL"));
                        *resp.status_mut() = hyper::http::StatusCode::FORBIDDEN;
                        return Ok(resp);
                    }
                }
            }
            if !must_proxied {
                log::debug!("connect to destination address {s5addr:?} without proxy");
                let stream = tokio::net::TcpStream::connect((host, port)).await?;
                return proxy_internal(stream, req).await;
            }
        }

        log::debug!("connect to SOCKS5 proxy server {server:?}");
        let stream = crate::create_s5_connect(server, CONNECT_TIMEOUT, &s5addr, middle_server).await?;
        proxy_internal(stream, req).await
    }
}

async fn proxy_internal<S>(stream: S, req: Request<hyper::body::Incoming>) -> Result<Response<BoxBody<Bytes, hyper::Error>>, std::io::Error>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + 'static + Unpin,
{
    let io = TokioIo::new(stream);
    let (mut sender, conn) = hyper::client::conn::http1::Builder::new()
        .preserve_header_case(true)
        .title_case_headers(true)
        .handshake(io)
        .await
        .map_err(std_io_error_other)?;
    tokio::task::spawn(async move {
        if let Err(err) = conn.await {
            log::error!("Connection failed: {err:?}");
        }
    });
    let resp = sender.send_request(req).await.map_err(std_io_error_other)?;
    Ok(resp.map(|b| b.boxed()))
}

/// Builds a boxed response body that carries no data.
fn empty() -> BoxBody<Bytes, hyper::Error> {
    let body = http_body_util::Empty::<Bytes>::new();
    // The body's error type is uninhabited, so the mapping closure can never run.
    body.map_err(|never| match never {}).boxed()
}

/// Builds a boxed response body holding `chunk` as its single data frame.
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
    let payload: Bytes = chunk.into();
    // The body's error type is uninhabited, so the mapping closure can never run.
    http_body_util::Full::new(payload).map_err(|never| match never {}).boxed()
}

/// Relays bytes between an upgraded client connection and the destination `dst`.
///
/// By default the destination is reached through the SOCKS5 `server` (optionally
/// chained through `middle_server`). With the `acl` feature enabled, the ACL may
/// instead request a direct connection (bypass) or forbid the destination.
///
/// # Errors
///
/// Returns an I/O error when the ACL blocks the destination, when the
/// destination cannot be resolved or connected, or when relaying fails.
async fn tunnel(upgraded: Upgraded, dst: Address, server: ProxyParameters, middle_server: Option<ProxyParameters>) -> std::io::Result<()> {
    #[cfg(feature = "acl")]
    {
        let mut must_proxied = true;
        if let Some(Some(acl)) = ACL_CENTER.get() {
            match acl.decide_target(&dst).await {
                TargetDecision::Proxy => must_proxied = true,
                TargetDecision::Bypass => must_proxied = false,
                TargetDecision::Block => {
                    return Err(std_io_error_other("blocked by ACL"));
                }
            }
        }
        if !must_proxied {
            log::debug!("connect to destination address {dst:?} without proxy");
            let mut upgraded = TokioIo::new(upgraded);

            // `to_socket_addrs` may perform a synchronous DNS lookup, so it is run
            // on the blocking pool instead of stalling an async worker thread.
            let addrs = tokio::task::spawn_blocking(move || {
                use std::net::ToSocketAddrs;
                dst.to_socket_addrs().map(|found| found.collect::<Vec<SocketAddr>>())
            })
            .await
            .map_err(std_io_error_other)??;
            if addrs.is_empty() {
                return Err(std_io_error_other("no address found"));
            }

            // Every resolved address is tried in order rather than only the first one.
            let mut server = tokio::net::TcpStream::connect(addrs.as_slice()).await?;
            let (from_client, from_server) = tokio::io::copy_bidirectional(&mut upgraded, &mut server).await?;
            log::debug!("client wrote {from_client} bytes and received {from_server} bytes");
            return Ok(());
        }
    }

    let mut upgraded = TokioIo::new(upgraded);
    let mut server = crate::create_s5_connect(server, CONNECT_TIMEOUT, &dst, middle_server).await?;
    let (from_client, from_server) = tokio::io::copy_bidirectional(&mut upgraded, &mut server).await?;
    log::debug!("client wrote {from_client} bytes and received {from_server} bytes");
    Ok(())
}

/// Checks a `Basic` authorization header value against the expected credentials.
///
/// * `credentials` - expected credentials; their `to_string()` form is what the
///   decoded header payload is compared with.
/// * `header_value` - raw header value, if the request carried one.
///
/// Returns `true` when the header is absent and the credentials are empty, or
/// when the header uses the `Basic` scheme (matched case-insensitively, as HTTP
/// authentication schemes are case-insensitive) and its base64 payload decodes
/// to exactly the expected credentials. Returns `false` otherwise, including
/// when the header is not valid text or not valid base64.
fn verify_basic_authorization(credentials: &UserKey, header_value: Option<&HeaderValue>) -> bool {
    // Render the expected credentials once instead of on every comparison.
    let expected = credentials.to_string();

    let Some(value) = header_value else {
        return expected.is_empty();
    };

    value
        .to_str()
        .ok()
        .and_then(|text| text.split_once(' '))
        .filter(|(scheme, _)| scheme.eq_ignore_ascii_case("Basic"))
        .and_then(|(_, token)| base64easy::decode(token.trim(), base64easy::EngineKind::Standard).ok())
        .is_some_and(|decoded| decoded.as_slice() == expected.as_bytes())
}

/// Decides whether a request may use the proxy.
///
/// Access is always granted when the configured credentials render to an empty
/// string; otherwise the supplied header value must pass
/// `verify_basic_authorization`.
fn is_proxy_authorized(credentials: &UserKey, header_value: Option<&HeaderValue>) -> bool {
    if credentials.to_string().is_empty() {
        return true;
    }
    verify_basic_authorization(credentials, header_value)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A request without any auth header is rejected once credentials are configured.
    #[test]
    fn connect_requires_proxy_auth_when_credentials_are_configured() {
        let configured = UserKey::new("alice", "secret");
        let authorized = is_proxy_authorized(&configured, None);
        assert!(!authorized);
    }

    /// A request without any auth header is accepted when no credentials are configured.
    #[test]
    fn connect_allows_missing_proxy_auth_when_no_credentials_are_configured() {
        let unconfigured = UserKey::default();
        let authorized = is_proxy_authorized(&unconfigured, None);
        assert!(authorized);
    }
}