// otoroshictl 0.0.17
//
// A CLI to manage your otoroshi clusters with style ;)
use futures_util::Future;
use http::Response;
use hyper::client::{HttpConnector, ResponseFuture};
use hyper::server::conn::AddrIncoming;
use hyper::service::Service;
use hyper::{Body, Client, Request, Server};
use hyper_rustls::TlsAcceptor;

use super::cache::{OtoroshiChallengePlugin, SidecarCache};
use super::config::OtoroshiSidecarConfig;

use crate::otoroshi::protocol::OtoroshiProtocol;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::{Context, Poll};

/// HTTP client used by the inbound proxy to forward requests to the local target.
///
/// Bundles the sidecar configuration with the hyper client that actually sends
/// the forwarded requests.
#[derive(Clone, Debug)]
struct InboundProxyClient {
    /// Sidecar configuration; `request` reads `spec.inbound.target_port` and
    /// `spec.inbound.target_version` from it.
    config: OtoroshiSidecarConfig,
    /// Hyper client built on the plain `HttpConnector`.
    http_client: Client<HttpConnector>,
}

impl InboundProxyClient {
    /// Builds a client for the given sidecar configuration, backed by a default
    /// hyper client.
    fn new(config: OtoroshiSidecarConfig) -> InboundProxyClient {
        InboundProxyClient {
            config,
            http_client: Client::new(),
        }
    }

    /// Forwards `req` to the local target application.
    ///
    /// The request URI is rewritten to
    /// `http://127.0.0.1:<target_port><path_and_query>`, where the port comes from
    /// `spec.inbound.target_port` (8080 when unset) and the path and query are
    /// taken from the incoming request (`/` when absent). The request version is
    /// set to HTTP/2 when `spec.inbound.target_version` is `"h2"` and to HTTP/1.1
    /// otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the rewritten URI cannot be parsed.
    fn request(&self, mut req: Request<Body>) -> ResponseFuture {
        // Borrow the inbound settings: this runs for every proxied request, so
        // cloning the whole configuration here would be pure overhead.
        let inbound = &self.config.spec.inbound;

        let path_and_query = req
            .uri()
            .path_and_query()
            .map(|pq| pq.as_str())
            .unwrap_or("/");
        let target = format!(
            "http://127.0.0.1:{}{}",
            inbound.target_port.unwrap_or(8080),
            path_and_query
        );
        *req.uri_mut() = target
            .parse()
            .expect("a loopback URI built from an already parsed path and query is valid");

        let version = match &inbound.target_version {
            Some(v) if v == "h2" => http::version::Version::HTTP_2,
            _ => http::version::Version::HTTP_11,
        };
        *req.version_mut() = version;

        self.http_client.request(req)
    }
}

/// Namespace type for the inbound proxy: its associated functions
/// (`start_http` / `start_https`) build the listening servers.
pub struct InboundProxy {}

impl InboundProxy {
    pub fn start_http(
        config: OtoroshiSidecarConfig,
        cache: Arc<SidecarCache>,
    ) -> impl Future<Output = std::result::Result<(), hyper::Error>> {
        let in_addr: SocketAddr = SocketAddr::new(
            "0.0.0.0".parse().unwrap(),
            config.spec.inbound.port.unwrap_or(15000),
        );
        let client = InboundProxyClient::new(config.clone());
        let make_svc = MakeSvc {
            client,
            config,
            cache,
        };
        let server = Server::bind(&in_addr).serve(make_svc);
        info!(target: "inbound_proxy", "listening on http://{}", in_addr);
        server
    }

    pub fn start_https(
        config: OtoroshiSidecarConfig,
        cache: Arc<SidecarCache>,
    ) -> impl Future<Output = std::result::Result<(), hyper::Error>> {
        let in_addr: SocketAddr = SocketAddr::new(
            "0.0.0.0".parse().unwrap(),
            config.spec.inbound.port.unwrap_or(15000),
        );
        let client = InboundProxyClient::new(config.clone());
        let mtls = config
            .clone()
            .spec
            .inbound
            .mtls
            .map(|i| i.enabled)
            .unwrap_or(false);
        let cert_id: String = config
            .clone()
            .spec
            .inbound
            .tls
            .map(|i| i.cert_id)
            .unwrap_or("none".to_string());
        let certificate = cache.wait_and_get_cert_by_id(cert_id.clone(), 30);
        let make_svc = MakeSvc {
            client,
            config: config.clone(),
            cache: cache.clone(),
        };
        if mtls {
            let incoming = AddrIncoming::bind(&in_addr).unwrap();
            let mut root = rustls::RootCertStore::empty();
            let ca_id = config
                .clone()
                .spec
                .inbound
                .mtls
                .map(|i| i.ca_cert_id)
                .unwrap_or("none".to_string());
            let ca = cache.wait_and_get_cert_by_id(ca_id.clone(), 30);
            for cert in ca.certs().into_iter() {
                let _ = root.add(&cert);
            }
            let client_auth = Arc::new(rustls::server::AllowAnyAuthenticatedClient::new(root));
            let server_config = rustls::ServerConfig::builder()
                .with_safe_default_cipher_suites()
                .with_safe_default_kx_groups()
                .with_protocol_versions(&[&rustls::version::TLS12, &rustls::version::TLS13])
                .unwrap()
                .with_client_cert_verifier(client_auth)
                .with_single_cert(certificate.certs(), certificate.key())
                .unwrap();
            let acceptor = TlsAcceptor::builder()
                .with_tls_config(server_config)
                .with_alpn_protocols(vec![b"h2".to_vec(), b"http/1.1".to_vec()])
                .with_incoming(incoming);
            let server = Server::builder(acceptor).serve(make_svc);
            info!(target: "inbound_proxy", "listening on https://{} with mTLS", in_addr);
            server
        } else {
            let incoming = AddrIncoming::bind(&in_addr).unwrap();
            let server_config = rustls::ServerConfig::builder()
                .with_safe_default_cipher_suites()
                .with_safe_default_kx_groups()
                .with_protocol_versions(&[&rustls::version::TLS12, &rustls::version::TLS13])
                .unwrap()
                .with_no_client_auth()
                .with_single_cert(certificate.certs(), certificate.key())
                .unwrap();
            let acceptor = TlsAcceptor::builder()
                .with_tls_config(server_config)
                .with_alpn_protocols(vec![b"h2".to_vec(), b"http/1.1".to_vec()])
                .with_incoming(incoming);
            let server = Server::builder(acceptor).serve(make_svc);
            info!(target: "inbound_proxy", "listening on https://{}", in_addr);
            server
        }
    }
}

/// Per-connection service of the inbound proxy: adjusts the request headers
/// and forwards the request through `client`.
struct Svc {
    /// Client used to forward the request to the local target.
    client: InboundProxyClient,
    /// Sidecar configuration (target hostname, Otoroshi protocol settings).
    config: OtoroshiSidecarConfig,
    /// Shared cache, used to look up a route by id when the Otoroshi protocol
    /// settings reference one.
    cache: Arc<SidecarCache>,
}

impl Svc {
    /// Process Otoroshi protocol challenge using the shared module.
    fn process_otoroshi_protocol(
        version: &str,
        state_value: &str,
        secret_in: &str,
        algo_in: &str,
        secret_out: &str,
        algo_out: &str,
    ) -> Option<String> {
        if version == "V2" {
            let protocol = OtoroshiProtocol::new_asymmetric(
                secret_in.as_bytes(),
                algo_in.parse().unwrap_or_default(),
                secret_out.as_bytes(),
                algo_out.parse().unwrap_or_default(),
            );
            match protocol.process_v2(state_value) {
                Ok(response_token) => Some(response_token),
                Err(e) => {
                    error!("Otoroshi protocol error: {}", e);
                    None
                }
            }
        } else {
            // V1: Simple echo
            Some(state_value.to_string())
        }
    }

    fn get_additional_headers(
        &self,
        in_headers: http::HeaderMap,
    ) -> HashMap<http::HeaderName, http::HeaderValue> {
        let mut headers = HashMap::new();
        if self
            .config
            .clone()
            .spec
            .inbound
            .otoroshi_protocol
            .map(|i| i.enabled)
            .unwrap_or(false)
        {
            let proto = self.config.clone().spec.inbound.otoroshi_protocol.unwrap();
            match proto.route_id {
                None => {
                    let version = proto.version.unwrap_or("V2".to_string());
                    let secret_in = proto.secret_in.unwrap_or("secret".to_string());
                    let algo_in = proto.algo_in.unwrap_or("HS512".to_string());
                    let secret_out = proto.secret_out.unwrap_or("secret".to_string());
                    let algo_out = proto.algo_out.unwrap_or("HS512".to_string());
                    let header_name_in = proto
                        .header_in_name
                        .map(|i| i.to_ascii_lowercase())
                        .unwrap_or("otoroshi-state".to_string());
                    let header_name_out = proto
                        .header_out_name
                        .map(|i| i.to_ascii_lowercase())
                        .unwrap_or("otoroshi-state-resp".to_string());

                    if let Some(state) = in_headers.get(&header_name_in)
                        && let Ok(state_str) = state.to_str()
                        && let Some(response) = Self::process_otoroshi_protocol(
                            &version,
                            state_str,
                            &secret_in,
                            &algo_in,
                            &secret_out,
                            &algo_out,
                        )
                        && let (Ok(name), Ok(value)) = (
                            http::HeaderName::from_str(&header_name_out),
                            http::HeaderValue::from_str(&response),
                        )
                    {
                        headers.insert(name, value);
                    }
                }
                Some(route_id) => match self.cache.get_route_by_id(route_id.clone()) {
                    None => error!("route with id '{}' does not exist", route_id.clone()),
                    Some(route) => {
                        match route.plugins.into_iter().find(|plugin| {
                            plugin.plugin == "cp:otoroshi.next.plugins.OtoroshiChallenge"
                        }) {
                            None => error!(
                                "the specified route with id '{}' does not have the OtoroshiChallenge plugin",
                                route_id
                            ),
                            Some(plugin) => {
                                if let Ok(config) =
                                    serde_json::from_value::<OtoroshiChallengePlugin>(plugin.config)
                                {
                                    let version = config.version;
                                    let secret_in = config.algo_to_backend.secret;
                                    let algo_in = format!("HS{}", config.algo_to_backend.size);
                                    let secret_out = config.algo_from_backend.secret;
                                    let algo_out = format!("HS{}", config.algo_from_backend.size);
                                    let header_name_in = config
                                        .request_header_name
                                        .unwrap_or("otoroshi-state".to_string());
                                    let header_name_out = config
                                        .response_header_name
                                        .unwrap_or("otoroshi-state-resp".to_string());

                                    if let Some(state) = in_headers.get(&header_name_in)
                                        && let Ok(state_str) = state.to_str()
                                        && let Some(response) = Self::process_otoroshi_protocol(
                                            &version,
                                            state_str,
                                            &secret_in,
                                            &algo_in,
                                            &secret_out,
                                            &algo_out,
                                        )
                                        && let (Ok(name), Ok(value)) = (
                                            http::HeaderName::from_str(&header_name_out),
                                            http::HeaderValue::from_str(&response),
                                        )
                                    {
                                        headers.insert(name, value);
                                    }
                                }
                            }
                        }
                    }
                },
            };
        }
        headers
    }
}

impl Service<Request<Body>> for Svc {
    type Response = Response<Body>;
    type Error = hyper::Error;
    type Future = ResponseFuture;

    fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
        let mut headers = req.headers().clone();
        let default_value = http::HeaderValue::from_str("localhost").unwrap();
        let authority_value = http::HeaderValue::from_str(
            req.uri()
                .authority()
                .map(|a| a.host())
                .unwrap_or("localhost"),
        )
        .unwrap();
        let orig_host = headers
            .get("Host")
            .or(headers.get(":authority"))
            .or(Some(&authority_value))
            .unwrap_or(&default_value);
        let host = match self.config.clone().spec.inbound.target_hostname {
            None => orig_host.to_owned(),
            Some(hostname) => http::HeaderValue::from_str(hostname.clone().as_str()).unwrap(),
        };
        let host_str = host.to_str().unwrap().to_string();
        let h_map = http::HeaderMap::from_iter(self.get_additional_headers(headers.clone()));
        info!(target: "inbound_proxy", "{} https://{}{}, otoroshi_protocol: {}", req.method(), host_str, req.uri().path(), if h_map.is_empty() { "no" } else { "yes" });
        headers.insert("Host", host.clone());
        headers.extend(h_map);
        *req.headers_mut() = headers;
        self.client.request(req)
    }
}

/// Service factory handed to the hyper server: produces one `Svc` per call,
/// each one receiving a clone of the fields below.
struct MakeSvc {
    /// Forwarding client cloned into every `Svc`.
    client: InboundProxyClient,
    /// Sidecar configuration cloned into every `Svc`.
    config: OtoroshiSidecarConfig,
    /// Shared cache; every `Svc` gets a clone of the `Arc`.
    cache: Arc<SidecarCache>,
}

impl<T> Service<T> for MakeSvc {
    type Response = Svc;
    type Error = hyper::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, _: T) -> Self::Future {
        let client = self.client.clone();
        let config = self.config.clone();
        let cache = self.cache.clone();
        let future = async move {
            Ok(Svc {
                client,
                config,
                cache,
            })
        };
        Box::pin(future)
    }
}