oso-cloud 0.5.4

Oso Cloud client
Documentation
use reqwest::dns::Resolve;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use serde::de::DeserializeOwned;
/// Internal API functionality not intended for public use.
use serde::{Deserialize, Serialize};
use std::io::Read;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use uuid::Uuid;

use crate::{Error, Fact};

/// Value sent as the `User-Agent` header by every client built here.
static APP_USER_AGENT: &str = "Oso Cloud (rust)";

/// Largest JSON request body `post` will send, in bytes (10 MiB). Larger
/// payloads are rejected locally before any request goes out.
const MAX_BODY_SIZE: usize = 10 << 20;

/// JSON error body returned by the Oso Cloud API alongside a non-success
/// HTTP status.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ApiError {
    /// Human-readable error description. It may be absent; in that case the
    /// client reports the HTTP status text instead.
    message: Option<String>,
}

/// Async HTTP client for the Oso Cloud API.
///
/// Cloning is cheap. The base URL and the last-seen offset live behind `Arc`s,
/// so every clone shares them.
#[derive(Clone)]
pub(crate) struct Client {
    /// `reqwest` client configured with the Oso default headers.
    client: reqwest::Client,
    /// Base URL of the Oso service. Requests go to `{url}/api/{path}`.
    pub(crate) url: Arc<String>,
    /// Most recent `OsoOffset` header value seen on a mutation response, or
    /// `None` if none has been seen yet. It is echoed back on later requests.
    last_offset: Arc<RwLock<Option<String>>>,
}

/// Options for connecting to the Oso Service
#[derive(Clone)]
pub struct ConnectOptions<R: Resolve + 'static> {
    /// A custom DNS resolver. Can be useful for adding custom DNS servers.
    /// Must implement the `reqwest::dns::Resolve` trait.
    pub dns_resolver: Option<Arc<R>>,

    /// Path to another root CA certificate to trust when doing certificate
    /// validation. Useful for trusting certificates signed with an internal
    /// certificate authority.
    pub ca_path: Option<String>,
}

/// Builder for [`Client`]. It wraps a `reqwest::ClientBuilder` that already
/// carries the Oso default headers and connection settings, plus the service
/// base URL.
pub(crate) struct ClientBuilder {
    /// Underlying `reqwest` builder; further options are layered onto it.
    client_builder: reqwest::ClientBuilder,
    /// Base URL of the Oso service, handed to the built [`Client`].
    url: Arc<String>,
}

impl ClientBuilder {
    pub(crate) fn new(url: &str, api_key: &str) -> Result<Self, Error> {
        let mut headers = HeaderMap::new();
        let mut auth_value = HeaderValue::from_str(&format!("Bearer {api_key}"))
            .map_err(|e| Error::Input(format!("invalid auth token: {e}")))?;
        auth_value.set_sensitive(true);
        headers.insert(AUTHORIZATION, auth_value);
        headers.insert("X-Oso-Client-Id", HeaderValue::from_static("rust"));
        headers.insert("Accept", HeaderValue::from_static("application/json"));
        headers.insert(
            "X-Oso-Instance-Id",
            HeaderValue::from_str(&Uuid::new_v4().to_string()).unwrap(),
        );
        let client_builder = reqwest::Client::builder()
            .user_agent(APP_USER_AGENT)
            .default_headers(headers)
            .http2_keep_alive_while_idle(true)
            .http2_keep_alive_interval(Duration::from_secs(30))
            .http2_keep_alive_timeout(Duration::from_secs(1));

        Ok(Self {
            client_builder,
            url: Arc::new(url.to_string()),
        })
    }

    /// Override the DNS resolver implementation.
    pub fn dns_resolver<R: Resolve + 'static>(mut self, resolver: Arc<R>) -> ClientBuilder {
        self.client_builder = self.client_builder.dns_resolver(resolver);
        self
    }

    /// Add another CA certificate to trust
    pub fn ca_path(mut self, ca_path: &str) -> Result<ClientBuilder, Error> {
        let mut buf = Vec::new();
        std::fs::File::open(ca_path)
            .map_err(|e| Error::Input(format!("Failed to read CA file at path {}: {}", ca_path, e)))?
            .read_to_end(&mut buf)
            .map_err(|e| Error::Input(format!("Failed to read CA file at path {}: {}", ca_path, e)))?;
        let cert = reqwest::Certificate::from_pem(&buf)?;
        self.client_builder = self.client_builder.add_root_certificate(cert);
        Ok(self)
    }

    pub fn build(self) -> Result<Client, Error> {
        let client = self.client_builder.build()?;
        Ok(Client {
            client,
            url: self.url.clone(),
            last_offset: Default::default(),
        })
    }
}

impl Client {
    async fn handle_error<T>(response: reqwest::Response) -> Result<T, Error>
    where
        T: DeserializeOwned,
    {
        if !response.status().is_success() {
            let status = response.status();
            let request_id = response
                .headers()
                .get("X-Request-ID")
                .and_then(|h| h.to_str().ok())
                .map(|s| s.to_string());
            let message = match response.json::<ApiError>().await {
                Ok(err) => err.message.unwrap_or_else(|| status.to_string()),
                Err(err) => {
                    tracing::warn!("failed to parse error response: {:#?}", err);
                    status.to_string()
                }
            };
            return Err(Error::Server { message, request_id });
        }

        Ok(response.json().await?)
    }

    fn set_last_offset(&self, response: &reqwest::Response) {
        let offset = response.headers().get("OsoOffset").and_then(|h| h.to_str().ok());
        if let Some(offset) = offset {
            *self.last_offset.write().unwrap() = Some(offset.to_string());
        }
    }

    #[tracing::instrument(skip(self), level = "trace", err)]
    pub async fn get<Params, Response>(&self, path: &str, params: Params) -> Result<Response, Error>
    where
        Params: std::fmt::Debug + Serialize,
        Response: DeserializeOwned,
    {
        let url = format!("{}/api/{path}", self.url, path = path);
        let mut request = self.client.get(url).query(&params);

        if let Some(offset) = self.last_offset.read().unwrap().as_ref() {
            request = request.header("OsoOffset", offset);
        }
        request = request.header(
            "X-Request-ID",
            HeaderValue::from_str(&Uuid::new_v4().to_string()).unwrap(),
        );
        let response = request.send().await?;
        Self::handle_error(response).await
    }

    #[tracing::instrument(skip(self), level = "trace", err)]
    pub async fn post<Body, Response>(&self, path: &str, body: Body, is_mutation: bool) -> Result<Response, Error>
    where
        Body: std::fmt::Debug + Serialize,
        Response: DeserializeOwned,
    {
        let url = format!("{}/api/{path}", self.url);

        let body_vec = serde_json::to_vec(&body).unwrap();
        if body_vec.len() > MAX_BODY_SIZE {
            return Err(Error::Input("Request payload too large".to_owned()));
        }

        let mut request = self.client.post(url).json(&body);
        if let Some(offset) = self.last_offset.read().unwrap().as_ref() {
            request = request.header("OsoOffset", offset);
        }
        request = request.header(
            "X-Request-ID",
            HeaderValue::from_str(&Uuid::new_v4().to_string()).unwrap(),
        );
        let response = request.send().await?;
        if is_mutation {
            self.set_last_offset(&response);
        }
        Self::handle_error(response).await
    }

    pub async fn bulk(&self, delete: &[Fact<'_>], tell: &[Fact<'_>]) -> Result<(), Error> {
        #[derive(Debug, Serialize)]
        struct BulkRequest<'a> {
            delete: &'a [Fact<'a>],
            tell: &'a [Fact<'a>],
        }

        let _: crate::ApiResult = self.post("bulk", BulkRequest { delete, tell }, true).await?;
        Ok(())
    }
}