use reqwest::dns::Resolve;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::io::Read;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use uuid::Uuid;
use crate::{Error, Fact};
/// Value sent in the `User-Agent` header of every request made by [`Client`].
static APP_USER_AGENT: &str = "Oso Cloud (rust)";
/// Largest serialized JSON request body, in bytes (10 MiB), that `Client::post` will send.
const MAX_BODY_SIZE: usize = 10 << 20;
/// JSON body of an unsuccessful API response, as decoded by `Client::handle_error`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ApiError {
    /// Human-readable description of the failure. It may be absent, in which case the
    /// HTTP status is reported instead.
    message: Option<String>,
}
/// HTTP client for the Oso Cloud API.
///
/// Cloning is cheap. Clones share the underlying reqwest connection pool, the base URL
/// and the consistency offset, because all of them live behind `Arc`s.
#[derive(Clone)]
pub(crate) struct Client {
    /// Base URL of the API; request paths are appended as `{url}/api/{path}`.
    pub(crate) url: Arc<String>,
    /// Most recent `OsoOffset` header seen on a mutation response. It is sent back on
    /// subsequent requests.
    last_offset: Arc<RwLock<Option<String>>>,
    /// reqwest client with the authentication and identification headers preconfigured.
    client: reqwest::Client,
}
#[derive(Clone)]
pub struct ConnectOptions<R: Resolve + 'static> {
pub dns_resolver: Option<Arc<R>>,
pub ca_path: Option<String>,
}
/// Accumulates configuration for a [`Client`] until [`ClientBuilder::build`] is called.
pub(crate) struct ClientBuilder {
    /// Base URL of the Oso Cloud API, handed over to the built `Client`.
    url: Arc<String>,
    /// Underlying reqwest builder carrying default headers, TLS roots and DNS settings.
    client_builder: reqwest::ClientBuilder,
}
impl ClientBuilder {
pub(crate) fn new(url: &str, api_key: &str) -> Result<Self, Error> {
let mut headers = HeaderMap::new();
let mut auth_value = HeaderValue::from_str(&format!("Bearer {api_key}"))
.map_err(|e| Error::Input(format!("invalid auth token: {e}")))?;
auth_value.set_sensitive(true);
headers.insert(AUTHORIZATION, auth_value);
headers.insert("X-Oso-Client-Id", HeaderValue::from_static("rust"));
headers.insert("Accept", HeaderValue::from_static("application/json"));
headers.insert(
"X-Oso-Instance-Id",
HeaderValue::from_str(&Uuid::new_v4().to_string()).unwrap(),
);
let client_builder = reqwest::Client::builder()
.user_agent(APP_USER_AGENT)
.default_headers(headers)
.http2_keep_alive_while_idle(true)
.http2_keep_alive_interval(Duration::from_secs(30))
.http2_keep_alive_timeout(Duration::from_secs(1));
Ok(Self {
client_builder,
url: Arc::new(url.to_string()),
})
}
pub fn dns_resolver<R: Resolve + 'static>(mut self, resolver: Arc<R>) -> ClientBuilder {
self.client_builder = self.client_builder.dns_resolver(resolver);
self
}
pub fn ca_path(mut self, ca_path: &str) -> Result<ClientBuilder, Error> {
let mut buf = Vec::new();
std::fs::File::open(ca_path)
.map_err(|e| Error::Input(format!("Failed to read CA file at path {}: {}", ca_path, e)))?
.read_to_end(&mut buf)
.map_err(|e| Error::Input(format!("Failed to read CA file at path {}: {}", ca_path, e)))?;
let cert = reqwest::Certificate::from_pem(&buf)?;
self.client_builder = self.client_builder.add_root_certificate(cert);
Ok(self)
}
pub fn build(self) -> Result<Client, Error> {
let client = self.client_builder.build()?;
Ok(Client {
client,
url: self.url.clone(),
last_offset: Default::default(),
})
}
}
impl Client {
async fn handle_error<T>(response: reqwest::Response) -> Result<T, Error>
where
T: DeserializeOwned,
{
if !response.status().is_success() {
let status = response.status();
let request_id = response
.headers()
.get("X-Request-ID")
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string());
let message = match response.json::<ApiError>().await {
Ok(err) => err.message.unwrap_or_else(|| status.to_string()),
Err(err) => {
tracing::warn!("failed to parse error response: {:#?}", err);
status.to_string()
}
};
return Err(Error::Server { message, request_id });
}
Ok(response.json().await?)
}
fn set_last_offset(&self, response: &reqwest::Response) {
let offset = response.headers().get("OsoOffset").and_then(|h| h.to_str().ok());
if let Some(offset) = offset {
*self.last_offset.write().unwrap() = Some(offset.to_string());
}
}
#[tracing::instrument(skip(self), level = "trace", err)]
pub async fn get<Params, Response>(&self, path: &str, params: Params) -> Result<Response, Error>
where
Params: std::fmt::Debug + Serialize,
Response: DeserializeOwned,
{
let url = format!("{}/api/{path}", self.url, path = path);
let mut request = self.client.get(url).query(¶ms);
if let Some(offset) = self.last_offset.read().unwrap().as_ref() {
request = request.header("OsoOffset", offset);
}
request = request.header(
"X-Request-ID",
HeaderValue::from_str(&Uuid::new_v4().to_string()).unwrap(),
);
let response = request.send().await?;
Self::handle_error(response).await
}
#[tracing::instrument(skip(self), level = "trace", err)]
pub async fn post<Body, Response>(&self, path: &str, body: Body, is_mutation: bool) -> Result<Response, Error>
where
Body: std::fmt::Debug + Serialize,
Response: DeserializeOwned,
{
let url = format!("{}/api/{path}", self.url);
let body_vec = serde_json::to_vec(&body).unwrap();
if body_vec.len() > MAX_BODY_SIZE {
return Err(Error::Input("Request payload too large".to_owned()));
}
let mut request = self.client.post(url).json(&body);
if let Some(offset) = self.last_offset.read().unwrap().as_ref() {
request = request.header("OsoOffset", offset);
}
request = request.header(
"X-Request-ID",
HeaderValue::from_str(&Uuid::new_v4().to_string()).unwrap(),
);
let response = request.send().await?;
if is_mutation {
self.set_last_offset(&response);
}
Self::handle_error(response).await
}
pub async fn bulk(&self, delete: &[Fact<'_>], tell: &[Fact<'_>]) -> Result<(), Error> {
#[derive(Debug, Serialize)]
struct BulkRequest<'a> {
delete: &'a [Fact<'a>],
tell: &'a [Fact<'a>],
}
let _: crate::ApiResult = self.post("bulk", BulkRequest { delete, tell }, true).await?;
Ok(())
}
}