redact-client 1.5.9

Receives requests for private data and decrypts them for secure display in the browser.
use async_trait::async_trait;
use http::StatusCode;
use reqwest::{Certificate, Response};
use std::collections::HashMap;
use std::fs::File;
use std::io::Read;
use std::ops::Deref;
use std::sync::Arc;
use thiserror::Error;
use warp::reject::Reject;

#[derive(Error, Debug)]
pub enum RelayError {
    #[error("Failure happened during relay")]
    RelayRequestError { source: Option<reqwest::Error> },
}

impl Reject for RelayError {}

/// Abstraction over a client that forwards requests to a relay service.
///
/// Implementors must be `Clone + Send + Sync` so they can be cloned and shared across
/// threads and async tasks.
#[async_trait]
pub trait Relayer: Clone + Send + Sync {
    /// Relays `path` to the service at `relay_url`, yielding the HTTP status on success.
    async fn relay(&self, path: String, relay_url: String) -> Result<StatusCode, RelayError>;

    /// Performs a GET against `relay_url`, yielding the raw response on success.
    async fn get(&self, relay_url: String) -> Result<Response, RelayError>;
}

/// Forwards every `Relayer` call on an `Arc<U>` to the wrapped `U`, so a shared relayer
/// can be passed anywhere a `Relayer` is expected.
#[async_trait]
impl<U> Relayer for Arc<U>
where
    U: Relayer,
{
    async fn relay(&self, path: String, relay_url: String) -> Result<StatusCode, RelayError> {
        // Explicitly typed as `&U` so the call dispatches to the inner relayer rather
        // than recursing into this `Arc` impl.
        let inner: &U = self.deref();
        inner.relay(path, relay_url).await
    }

    async fn get(&self, relay_url: String) -> Result<Response, RelayError> {
        let inner: &U = self.deref();
        inner.get(relay_url).await
    }
}

/// A `Relayer` backed by a `reqwest` HTTP client.
///
/// When built with `MutualTLSRelayer::new`, the client presents a client TLS identity
/// loaded from a PEM file.
#[derive(Clone, Debug)]
pub struct MutualTLSRelayer {
    /// HTTP client used for all relay requests.
    pub client: reqwest::Client,
}

impl MutualTLSRelayer {
    pub fn new(
        pem_file_path: String,
        additional_ca_certs: Option<&[Certificate]>,
    ) -> Result<MutualTLSRelayer, RelayError> {
        // Load the client TLS certificate and key
        let mut buf = Vec::new();
        File::open(pem_file_path)
            .unwrap()
            .read_to_end(&mut buf)
            .unwrap();
        let pkcs12 = reqwest::Identity::from_pem(&buf).unwrap();

        // Build the relay HTTP client, adding in provided certificate as additional CA certs
        let mut client_builder = reqwest::Client::builder().identity(pkcs12).use_rustls_tls();
        if let Some(ca_certs) = additional_ca_certs {
            // In order for the additional certs to be used, the built-in root certs must be disabled
            client_builder = client_builder.tls_built_in_root_certs(false);
            for cert in ca_certs.iter() {
                client_builder = client_builder.add_root_certificate(cert.clone())
            }
        }
        Ok(MutualTLSRelayer {
            client: client_builder.build().unwrap(),
        })
    }
}

#[async_trait]
impl Relayer for MutualTLSRelayer {
    async fn relay(&self, path: String, relay_url: String) -> Result<StatusCode, RelayError> {
        let mut req_body = HashMap::new();
        req_body.insert("path", path);
        req_body.insert("userId", "abc".to_owned());

        self.client
            .post(relay_url)
            .json(&req_body)
            .send()
            .await
            .and_then(|response| response.error_for_status())
            .and_then(|response| Ok(response.status()))
            .map_err(|source| RelayError::RelayRequestError {
                source: Some(source),
            })
    }

    async fn get(&self, relay_url: String) -> Result<Response, RelayError> {
        self.client
            .get(relay_url)
            .send()
            .await
            .and_then(|response| response.error_for_status())
            .map_err(|source| RelayError::RelayRequestError {
                source: Some(source),
            })
    }
}

// Test-only support: a mockall-generated `MockRelayer` standing in for a real
// `Relayer`. The module is `pub`, presumably so other test modules in the crate can
// import the mock — confirm against callers.
#[cfg(test)]
pub mod tests {
    use super::{RelayError, Relayer};
    use async_trait::async_trait;
    use http::StatusCode;
    use mockall::predicate::*;
    use mockall::*;
    use reqwest::Response;

    // `mock!` defines `MockRelayer`, a mock type implementing both `Clone` (required by
    // the `Relayer` supertrait bound) and `Relayer`. The `#[async_trait]` attribute
    // mirrors the real trait's async method signatures.
    mock! {
    pub Relayer {}
    impl Clone for Relayer {
            fn clone(&self) -> Self;
    }

    #[async_trait]
    impl Relayer for MockRelayer {
        async fn relay(&self, path: String, relay_url: String) -> Result<StatusCode, RelayError>;
        async fn get(&self, relay_url: String) -> Result<Response, RelayError>;
    }
    }
}