product-os-proxy 0.0.17

Product OS : Proxy builds on the work of hudsucker, taking it to the next level with a man-in-the-middle proxy server that can tunnel traffic through a VPN utilising Product OS : VPN.
mod internal;

pub mod builder;

pub use builder::ProxyBuilder;

use crate::mitm::{certificate_authority::CertificateAuthority, Error, HttpHandler, WebSocketHandler};

use builder::AddrOrListener;

use hyper_util::{
    client::legacy::{
        connect::{Connect, HttpConnector},
        Client,
    },
    rt::{TokioExecutor, TokioIo},
    server::conn::auto,
    server::conn::auto::Builder,
};

use internal::InternalProxy;
use std::{future::Future, sync::Arc};
use std::future::Pending;
use tokio::net::TcpListener;
use tokio_graceful::Shutdown;
use tokio_tungstenite::Connector;

// use product_os_server::service_fn;
use hyper::service::service_fn;
use hyper_rustls::HttpsConnector;
use product_os_http_body::{BodyBytes as Body, BodyBytes};

#[cfg(feature = "tor")]
use tor_rtcompat::PreferredRuntime;

use crate::mitm::certificate_authority::RcgenAuthority;
use crate::ProxyMiddleware;

/// A configured man-in-the-middle proxy server, ready to be run with `start`.
///
/// Type parameters:
/// * `C`  - connector used by the outbound hyper-util [`Client`].
/// * `CA` - certificate authority implementation shared across connections.
/// * `H`  - HTTP handler, cloned for every proxied request.
/// * `W`  - WebSocket handler, cloned for every proxied request.
/// * `F`  - future whose completion triggers graceful shutdown.
///
/// Instances are normally obtained through [`ProxyBuilder`].
pub struct Proxy<C, CA, H, W, F> {
    /// Either a socket address to bind when the proxy starts, or an
    /// already-bound listener to accept connections from.
    address_or_listener: AddrOrListener,
    /// Certificate authority, shared (via `Arc`) with every connection.
    ca: Arc<CA>,
    /// Outbound HTTP client handed to each per-request internal proxy.
    client: Client<C, BodyBytes>,
    /// Optional TLS connector for outbound WebSocket connections.
    websocket_connector: Option<Connector>,

    /// Handler invoked for HTTP traffic.
    http_handler: H,
    /// Handler invoked for WebSocket traffic.
    websocket_handler: W,

    // NOTE(review): `certificates` is not read anywhere in this part of the
    // file; presumably consumed by the builder or other modules - confirm.
    certificates: product_os_security::certificates::Certificates,
    /// Optional custom request client passed through to each internal proxy.
    custom_requester: Option<product_os_request::ProductOSRequestClient>,
    /// Compression configuration passed through to each internal proxy.
    compression: product_os_configuration::NetworkProxyCompression,

    /// Optional pre-configured connection builder; when `None`, `start`
    /// constructs a default HTTP/1 + HTTP/2 builder.
    server: Option<Builder<TokioExecutor>>,
    /// Future that, once resolved, signals the proxy to shut down.
    graceful_shutdown: F,

    /// Optional Tor client passed through to each internal proxy
    /// (only with the `tor` feature).
    #[cfg(feature = "tor")]
    tor_client: Option<arti_client::TorClient<tor_rtcompat::PreferredRuntime>>,
    /// Optional VPN client passed through to each internal proxy
    /// (only with the `vpn` feature).
    #[cfg(feature = "vpn")]
    vpn_client: Option<product_os_vpn::ProductOSVPN>
}

/// Entry points for constructing a [`ProxyBuilder`] with the default
/// certificate authority (`RcgenAuthority`), the default middleware handlers
/// (`ProxyMiddleware`) and a shutdown future that never resolves.
impl Proxy<(), RcgenAuthority, ProxyMiddleware, ProxyMiddleware, Pending<()>> {
    /// Creates a builder whose outbound client connector type is
    /// `HttpsConnector<HttpConnector>`.
    pub fn https_builder() -> ProxyBuilder<
        HttpsConnector<HttpConnector>,
        RcgenAuthority,
        ProxyMiddleware,
        ProxyMiddleware,
        Pending<()>,
    > {
        ProxyBuilder::new()
    }

    /// Creates a builder whose outbound client connector type is the plain
    /// `HttpConnector`.
    pub fn http_builder() -> ProxyBuilder<
        HttpConnector,
        RcgenAuthority,
        ProxyMiddleware,
        ProxyMiddleware,
        Pending<()>,
    > {
        ProxyBuilder::new()
    }
}

impl<C, CA, H, W, F> Proxy<C, CA, H, W, F>
where
    C: Connect + Clone + Send + Sync + 'static,
    CA: CertificateAuthority,
    H: HttpHandler + Clone,
    W: WebSocketHandler + Clone,
    F: Future<Output = ()> + Send + 'static,
{

    pub async fn start(self) -> Result<(), Error> {
        let listener = match self.address_or_listener {
            AddrOrListener::Addr(addr) => TcpListener::bind(addr).await?,
            AddrOrListener::Listener(listener) => listener,
        };

        let ca = Arc::clone(&self.ca);

        let client = self.client.clone();
        let websocket_connector = self.websocket_connector.clone();

        let http_handler = self.http_handler.clone();
        let websocket_handler = self.websocket_handler.clone();

        let shutdown = Shutdown::new(self.graceful_shutdown);
        let guard = shutdown.guard_weak();

        #[cfg(feature = "tor")]
        let tor_client = self.tor_client.clone();

        #[cfg(feature = "vpn")]
        let vpn_client = self.vpn_client.clone();

        let server = self.server.unwrap_or_else(|| {
            let mut builder = Builder::new(TokioExecutor::new());
            builder
                .http1()
                .title_case_headers(true)
                .preserve_header_case(true)
                .http2();
                //.enable_connect_protocol();
            builder
        });


        loop {
            let custom_requester = match &self.custom_requester {
                Some(custom_requester) => Some(custom_requester.clone()),
                None => None
            };

            let compression = self.compression.clone();

            tokio::select! {
                res = listener.accept() => {
                    let (tcp, client_addr) = match res {
                        Ok((tcp, client_addr)) => (tcp, client_addr),
                        Err(e) => {
                            tracing::error!("Failed to accept incoming connection: {}", e);
                            continue;
                        }
                    };

                    let server = server.clone();
                    let client = client.clone();
                    let ca = Arc::clone(&ca);
                    let http_handler = http_handler.clone();
                    let websocket_handler = websocket_handler.clone();
                    let websocket_connector = websocket_connector.clone();

                    #[cfg(feature = "tor")]
                    let tor_client = tor_client.clone();
                    #[cfg(feature = "vpn")]
                    let vpn_client = vpn_client.clone();

                    shutdown.spawn_task_fn(move |guard| async move {
                        let conn = server.serve_connection_with_upgrades(
                            TokioIo::new(tcp),
                            service_fn(|req| {
                                InternalProxy {
                                    ca: Arc::clone(&ca),
                                    client: client.clone(),
                                    http_handler: http_handler.clone(),
                                    websocket_handler: websocket_handler.clone(),
                                    websocket_connector: websocket_connector.clone(),
                                    server: server.clone(),
                                    client_addr,
                                    #[cfg(feature = "tor")]
                                    tor_client: tor_client.clone(),
                                    #[cfg(feature = "vpn")]
                                    vpn_client: vpn_client.clone(),
                                    custom_requester: custom_requester.clone(),
                                    compression: compression.clone()
                                }
                                .proxy(req)
                            }),
                        );


                        let mut conn = std::pin::pin!(conn);

                        if let Err(err) = tokio::select! {
                            conn = conn.as_mut() => conn,
                            _ = guard.cancelled() => {
                                conn.as_mut().graceful_shutdown();
                                conn.await
                            }
                        } {
                            tracing::error!("Error serving connection: {}", err);
                        }
                    });
                }
                _ = guard.cancelled() => {
                    break;
                }
            }
        }

        Ok(())
    }
}