spacegate-kernel 0.2.0-alpha.4

A library-first, lightweight, high-performance, cloud-native supported API gateway
Documentation
use crate::{extension::Reflect, SgBody, SgResponseExt};

use hyper::StatusCode;
use hyper::{Request, Response};
use hyper_rustls::HttpsConnector;
use hyper_rustls::{ConfigBuilderExt, HttpsConnectorBuilder};
use hyper_util::{
    client::legacy::{connect::HttpConnector, Client},
    rt::TokioExecutor,
};
use std::{
    collections::HashMap,
    sync::{Arc, Mutex, OnceLock},
    time::Duration,
};
use tokio_rustls::rustls::{self, client::danger::ServerCertVerifier, SignatureScheme};

#[derive(Debug, Clone)]
pub struct NoCertificateVerification {}
impl ServerCertVerifier for NoCertificateVerification {
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            SignatureScheme::RSA_PKCS1_SHA256,
            SignatureScheme::RSA_PKCS1_SHA384,
            SignatureScheme::RSA_PKCS1_SHA512,
            SignatureScheme::RSA_PSS_SHA256,
            SignatureScheme::RSA_PSS_SHA384,
            SignatureScheme::RSA_PSS_SHA512,
            SignatureScheme::ECDSA_NISTP256_SHA256,
            SignatureScheme::ECDSA_NISTP384_SHA384,
            SignatureScheme::ECDSA_NISTP521_SHA512,
            SignatureScheme::ED25519,
            SignatureScheme::ED448,
        ]
    }
}

fn get_rustls_config_dangerous() -> rustls::ClientConfig {
    let store = rustls::RootCertStore::empty();
    let _ = tokio_rustls::rustls::crypto::ring::default_provider().install_default();
    let mut config = rustls::ClientConfig::builder().with_root_certificates(store).with_no_client_auth();
    // completely disable cert-verification
    let mut dangerous_config = rustls::ClientConfig::dangerous(&mut config);
    dangerous_config.set_certificate_verifier(Arc::new(NoCertificateVerification {}));
    config
}

pub fn get_client() -> HttpClient {
    ClientRepo::global().get_default()
}

#[derive(Debug)]
pub struct ClientRepo {
    default: HttpClient,
    repo: Mutex<HashMap<String, HttpClient>>,
}

impl Default for ClientRepo {
    fn default() -> Self {
        let config = get_rustls_config_dangerous();
        let default = HttpClient::new(config);
        Self {
            default,
            repo: Default::default(),
        }
    }
}

static mut GLOBAL: OnceLock<ClientRepo> = OnceLock::new();
impl ClientRepo {
    pub fn get(&self, code: &str) -> Option<HttpClient> {
        self.repo.lock().expect("failed to lock client repo").get(code).cloned()
    }
    pub fn get_or_default(&self, code: &str) -> HttpClient {
        self.get(code).unwrap_or_else(|| self.default.clone())
    }
    pub fn get_default(&self) -> HttpClient {
        self.default.clone()
    }
    pub fn register(&self, code: &str, client: HttpClient) {
        self.repo.lock().expect("failed to lock client repo").insert(code.to_string(), client);
    }
    pub fn set_default(&mut self, client: HttpClient) {
        self.default = client;
    }
    pub fn global() -> &'static Self {
        unsafe { std::ptr::addr_of!(GLOBAL).cast_mut().as_mut().expect("invalid static global client repo") }.get_or_init(Default::default)
    }

    /// # Safety
    /// This function is not thread safe, it should be called before any other thread is spawned.
    pub unsafe fn set_global_default(client: HttpClient) {
        GLOBAL.get_or_init(Default::default);
        GLOBAL.get_mut().expect("global not set").set_default(client);
    }
}
#[derive(Debug)]
pub struct SgHttpClientConfig {
    pub tls_config: rustls::ClientConfig,
}

#[derive(Debug, Clone)]
pub struct HttpClient {
    inner: Client<HttpsConnector<HttpConnector>, SgBody>,
}

impl Default for HttpClient {
    fn default() -> Self {
        Self::new(rustls::ClientConfig::builder().with_native_roots().expect("failed to init rustls config").with_no_client_auth())
    }
}

impl HttpClient {
    pub fn new(tls_config: rustls::ClientConfig) -> Self {
        HttpClient {
            inner: Client::builder(TokioExecutor::new()).build(HttpsConnectorBuilder::new().with_tls_config(tls_config).https_or_http().enable_http1().enable_http2().build()),
        }
    }
    pub fn new_h1_only(tls_config: rustls::ClientConfig) -> Self {
        HttpClient {
            inner: Client::builder(TokioExecutor::new()).build(HttpsConnectorBuilder::new().with_tls_config(tls_config).https_or_http().enable_http1().build()),
        }
    }
    pub fn new_dangerous() -> Self {
        let config = get_rustls_config_dangerous();
        Self::new(config)
    }
    pub fn new_dangerous_h1_only() -> Self {
        let config = get_rustls_config_dangerous();
        Self::new_h1_only(config)
    }
    pub async fn request(&mut self, mut req: Request<SgBody>) -> Response<SgBody> {
        let reflect = req.extensions_mut().remove::<Reflect>();
        match self.inner.request(req).await.map_err(Response::bad_gateway) {
            Ok(mut response) => {
                if let Some(reflect) = reflect {
                    response.extensions_mut().extend(reflect.into_inner());
                }
                response.map(SgBody::new)
            }
            Err(err) => err,
        }
    }
    pub async fn request_timeout(&mut self, req: Request<SgBody>, timeout: Duration) -> Response<SgBody> {
        let fut = self.request(req);
        let resp = tokio::time::timeout(timeout, fut).await;
        match resp {
            Ok(resp) => resp,
            Err(_) => Response::with_code_message(StatusCode::GATEWAY_TIMEOUT, "request timeout"),
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;
    #[tokio::test]
    async fn test_client() {
        let mut client = get_client();
        let req = Request::builder().uri("https://www.baidu.com").body(SgBody::empty()).unwrap();
        let resp = client.request(req).await;
        let (part, body) = resp.into_parts();
        let body = body.dump().await.unwrap();
        let dumped = body.get_dumped().expect("no body");
        println!("{part:?}, {}", String::from_utf8_lossy(dumped));
    }
}