mod error;
pub mod load;
pub mod tokens;
use std::convert::TryInto;
use std::path::Path;
use std::sync::Arc;
use reqwest::header::{self, HeaderMap};
use reqwest::Client as HttpClient;
use reqwest::{Body, RequestBuilder, StatusCode};
use tokio_stream::{Stream, StreamExt};
use tracing::{debug, info, instrument, trace};
use url::Url;
use crate::provider::{Provider, ProviderError};
use crate::signature::{KeyRing, SignatureRole};
use crate::verification::{Verified, VerifiedInvoice};
use crate::{Id, Invoice, Signed, VerificationStrategy};
pub use error::ClientError;
/// Convenience alias for results whose error type is [`ClientError`].
pub type Result<T> = std::result::Result<T, ClientError>;
/// Path segment (relative to the base URL) used for invoice and parcel requests.
pub const INVOICE_ENDPOINT: &str = "_i";
/// Path segment (relative to the base URL) used for invoice queries.
pub const QUERY_ENDPOINT: &str = "_q";
/// Path segment (relative to the base URL) used for relationship lookups such as missing parcels.
pub const RELATIONSHIP_ENDPOINT: &str = "_r";
/// Path segment for the login endpoint.
// NOTE(review): not referenced in the visible part of this module; presumably consumed by the
// token managers -- confirm.
pub const LOGIN_ENDPOINT: &str = "login";
/// Path segment (relative to the base URL) used to fetch the host keyring.
pub const BINDLE_KEYS_ENDPOINT: &str = "bindle-keys";
// MIME type sent in the default `Accept` header and as the `Content-Type` of invoice uploads.
const TOML_MIME_TYPE: &str = "application/toml";
/// An HTTP client for a Bindle server, generic over the token manager `T` that authenticates
/// outgoing requests.
#[derive(Clone)]
pub struct Client<T> {
// Underlying reqwest client used to send every request.
client: HttpClient,
// Base URL that endpoint paths are joined onto.
base_url: Url,
// Applies the auth header to each request before it is sent.
token_manager: T,
// Strategy used to verify invoices fetched from the server.
verification_strategy: VerificationStrategy,
// Keyring handed to the verification strategy when verifying fetched invoices.
keyring: Arc<KeyRing>,
}
/// Builder that collects HTTP and verification options before constructing a [`Client`].
pub struct ClientBuilder {
// When true, `build` calls `http2_prior_knowledge()` on the reqwest client builder.
http2_prior_knowledge: bool,
// When true, `build` calls `danger_accept_invalid_certs(true)` on the reqwest client builder.
danger_accept_invalid_certs: bool,
// Verification strategy handed to the built client.
verification_strategy: VerificationStrategy,
}
impl Default for ClientBuilder {
    /// Returns a builder with both HTTP toggles switched off and a
    /// `MultipleAttestationGreedy` verification strategy containing only the `Host` role.
    fn default() -> Self {
        let verification_strategy =
            VerificationStrategy::MultipleAttestationGreedy(vec![SignatureRole::Host]);
        Self {
            http2_prior_knowledge: false,
            danger_accept_invalid_certs: false,
            verification_strategy,
        }
    }
}
impl ClientBuilder {
    /// Sets whether the HTTP client should be built with `http2_prior_knowledge()`.
    pub fn http2_prior_knowledge(self, http2_prior_knowledge: bool) -> Self {
        Self {
            http2_prior_knowledge,
            ..self
        }
    }

    /// Sets whether the HTTP client should be built with `danger_accept_invalid_certs(true)`.
    pub fn danger_accept_invalid_certs(self, danger_accept_invalid_certs: bool) -> Self {
        Self {
            danger_accept_invalid_certs,
            ..self
        }
    }

    /// Replaces the verification strategy the built client will use.
    pub fn verification_strategy(self, verification_strategy: VerificationStrategy) -> Self {
        Self {
            verification_strategy,
            ..self
        }
    }

    /// Consumes the builder and produces a [`Client`] for `base_url`, using the given token
    /// manager and keyring.
    ///
    /// # Errors
    ///
    /// Returns an error if the base URL cannot be parsed, or `ClientError::Other` (carrying the
    /// stringified cause) if the underlying HTTP client cannot be constructed.
    pub fn build<T>(
        self,
        base_url: &str,
        token_manager: T,
        keyring: Arc<KeyRing>,
    ) -> Result<Client<T>> {
        let Self {
            http2_prior_knowledge,
            danger_accept_invalid_certs,
            verification_strategy,
        } = self;
        let (base_url, default_headers) = base_url_and_headers(base_url)?;
        let http = HttpClient::builder()
            .and_if(http2_prior_knowledge, |b| b.http2_prior_knowledge())
            .and_if(danger_accept_invalid_certs, |b| {
                b.danger_accept_invalid_certs(true)
            })
            .default_headers(default_headers)
            .build()
            .map_err(|err| ClientError::Other(err.to_string()))?;
        Ok(Client {
            client: http,
            base_url,
            token_manager,
            verification_strategy,
            keyring,
        })
    }
}
pub(crate) fn base_url_and_headers(base_url: &str) -> Result<(Url, HeaderMap)> {
let mut base = base_url.to_owned();
if !base.ends_with('/') {
info!("Provided base URL missing trailing slash, adding...");
base.push('/');
}
let base_parsed = Url::parse(&base)?;
let mut headers = header::HeaderMap::new();
headers.insert(header::ACCEPT, TOML_MIME_TYPE.parse().unwrap());
Ok((base_parsed, headers))
}
impl<T: tokens::TokenManager> Client<T> {
pub fn new(base_url: &str, token_manager: T, keyring: Arc<KeyRing>) -> Result<Self> {
ClientBuilder::default().build(base_url, token_manager, keyring)
}
pub fn builder() -> ClientBuilder {
ClientBuilder::default()
}
/// Sends an arbitrary request to `path` (joined onto the base URL) with the given method and
/// optional body, after letting the token manager apply its auth header.
///
/// The response is returned as-is; its status code is not inspected here.
///
/// # Errors
///
/// Fails if `path` cannot be joined onto the base URL, if the auth header cannot be applied, or
/// if sending the request fails.
#[instrument(level = "trace", skip(self, body))]
pub async fn raw(
    &self,
    method: reqwest::Method,
    path: &str,
    body: Option<impl Into<reqwest::Body>>,
) -> Result<reqwest::Response> {
    let url = self.base_url.join(path)?;
    let builder = self.client.request(method, url);
    let builder = self.token_manager.apply_auth_header(builder).await?;
    let builder = if let Some(payload) = body {
        builder.body(payload)
    } else {
        builder
    };
    Ok(builder.send().await?)
}
/// Serializes `inv` to TOML and posts it to the invoice endpoint, returning the server's
/// parsed creation response.
#[instrument(level = "trace", skip(self, inv), fields(id = %inv.bindle.id))]
pub async fn create_invoice(
    &self,
    inv: crate::Invoice,
) -> Result<crate::InvoiceCreateResponse> {
    // Build (and authenticate) the request first, then serialize, matching the original order.
    let builder = self.create_invoice_builder().await?;
    let payload = toml::to_vec(&inv)?;
    self.create_invoice_request(builder.body(payload)).await
}
/// Streams the file at `file_path` as the body of an invoice creation request and returns the
/// server's parsed creation response.
#[instrument(level = "trace", skip(self, file_path), fields(path = %file_path.as_ref().display()))]
pub async fn create_invoice_from_file<P: AsRef<Path>>(
    &self,
    file_path: P,
) -> Result<crate::InvoiceCreateResponse> {
    let owned_path = file_path.as_ref().to_owned();
    debug!("Loading invoice from file");
    let stream = load::raw(owned_path).await?;
    debug!("Successfully loaded invoice stream");
    let builder = self.create_invoice_builder().await?;
    let request = builder.body(Body::wrap_stream(stream));
    self.create_invoice_request(request).await
}
/// Builds an authenticated `POST` request to the invoice endpoint with a TOML content type.
///
/// # Errors
///
/// Fails if the invoice endpoint cannot be joined onto the base URL (for example when the
/// configured base URL parses but cannot serve as a base), or if the auth header cannot be
/// applied. The join failure is propagated rather than panicking.
async fn create_invoice_builder(&self) -> Result<RequestBuilder> {
    let url = self.base_url.join(INVOICE_ENDPOINT)?;
    let req = self
        .client
        .post(url)
        .header(header::CONTENT_TYPE, TOML_MIME_TYPE);
    self.token_manager.apply_auth_header(req).await
}
/// Sends a prepared invoice creation request, checks the response status, and parses the TOML
/// body into an `InvoiceCreateResponse`.
async fn create_invoice_request(
    &self,
    req: RequestBuilder,
) -> Result<crate::InvoiceCreateResponse> {
    trace!(?req);
    let raw_resp = req.send().await?;
    let checked = unwrap_status(raw_resp, Endpoint::Invoice, Operation::Create).await?;
    let body = checked.bytes().await?;
    let created: crate::InvoiceCreateResponse = toml::from_slice(&body)?;
    Ok(created)
}
/// Fetches the invoice with the given ID and returns it after verification.
///
/// `id` may be anything convertible into an `Id`; a failed conversion is returned as an error.
#[instrument(level = "trace", skip(self, id), fields(invoice_id))]
pub async fn get_invoice<I>(&self, id: I) -> Result<VerifiedInvoice<Invoice>>
where
    I: TryInto<Id>,
    I::Error: Into<ClientError>,
{
    let parsed_id = id.try_into().map_err(|e| e.into())?;
    // The ID is only known after parsing, so it is recorded on the span here.
    tracing::span::Span::current().record("invoice_id", &tracing::field::display(&parsed_id));
    let path = format!("{}/{}", INVOICE_ENDPOINT, parsed_id);
    let url = self.base_url.join(&path)?;
    self.get_invoice_request(url).await
}
/// Like `get_invoice`, but adds the `yanked=true` query parameter to the request URL.
#[instrument(level = "trace", skip(self, id), fields(invoice_id))]
pub async fn get_yanked_invoice<I>(&self, id: I) -> Result<VerifiedInvoice<Invoice>>
where
    I: TryInto<Id>,
    I::Error: Into<ClientError>,
{
    let parsed_id = id.try_into().map_err(|e| e.into())?;
    // The ID is only known after parsing, so it is recorded on the span here.
    tracing::span::Span::current().record("invoice_id", &tracing::field::display(&parsed_id));
    let path = format!("{}/{}", INVOICE_ENDPOINT, parsed_id);
    let mut target = self.base_url.join(&path)?;
    target.set_query(Some("yanked=true"));
    self.get_invoice_request(target).await
}
/// Performs an authenticated `GET` on `url`, parses the TOML body as an `Invoice`, and verifies
/// it with the client's verification strategy and keyring.
async fn get_invoice_request(&self, url: Url) -> Result<VerifiedInvoice<Invoice>> {
    let req = self
        .token_manager
        .apply_auth_header(self.client.get(url))
        .await?;
    trace!(?req);
    let raw_resp = req.send().await?;
    let checked = unwrap_status(raw_resp, Endpoint::Invoice, Operation::Get).await?;
    let body = checked.bytes().await?;
    let inv: Invoice = toml::from_slice(&body)?;
    let verified = self.verification_strategy.verify(inv, &self.keyring)?;
    Ok(verified)
}
/// Queries the server for invoices matching `query_opts` and returns the parsed matches.
///
/// # Errors
///
/// Fails if the query endpoint cannot be joined onto the base URL (propagated rather than
/// panicking), if the auth header cannot be applied, if the request fails, if the server
/// responds with a non-success status, or if the body is not valid TOML for the expected type.
#[instrument(level = "trace", skip(self))]
pub async fn query_invoices(
    &self,
    query_opts: crate::QueryOptions,
) -> Result<crate::search::Matches> {
    let url = self.base_url.join(QUERY_ENDPOINT)?;
    let req = self.client.get(url).query(&query_opts);
    let req = self.token_manager.apply_auth_header(req).await?;
    trace!(?req);
    let resp = req.send().await?;
    let resp = unwrap_status(resp, Endpoint::Query, Operation::Query).await?;
    Ok(toml::from_slice(&resp.bytes().await?)?)
}
/// Yanks the invoice with the given ID by issuing an authenticated `DELETE` against it.
#[instrument(level = "trace", skip(self, id), fields(invoice_id))]
pub async fn yank_invoice<I>(&self, id: I) -> Result<()>
where
    I: TryInto<Id>,
    I::Error: Into<ClientError>,
{
    let parsed_id = id.try_into().map_err(|e| e.into())?;
    // The ID is only known after parsing, so it is recorded on the span here.
    tracing::span::Span::current().record("invoice_id", &tracing::field::display(&parsed_id));
    let path = format!("{}/{}", INVOICE_ENDPOINT, parsed_id);
    let url = self.base_url.join(&path)?;
    let req = self
        .token_manager
        .apply_auth_header(self.client.delete(url))
        .await?;
    trace!(?req);
    let resp = req.send().await?;
    // Only the status matters; the response body is discarded.
    unwrap_status(resp, Endpoint::Invoice, Operation::Yank)
        .await
        .map(|_| ())
}
/// Uploads an in-memory parcel (`data`) identified by `parcel_sha` for the given bindle.
#[instrument(level = "trace", skip(self, bindle_id, data), fields(invoice_id, data_len = data.len()))]
pub async fn create_parcel<I>(
    &self,
    bindle_id: I,
    parcel_sha: &str,
    data: Vec<u8>,
) -> Result<()>
where
    I: TryInto<Id>,
    I::Error: Into<ClientError>,
{
    let parsed_id = bindle_id.try_into().map_err(|e| e.into())?;
    // The ID is only known after parsing, so it is recorded on the span here.
    tracing::span::Span::current().record("invoice_id", &tracing::field::display(&parsed_id));
    let builder = self.create_parcel_builder(&parsed_id, parcel_sha).await?;
    self.create_parcel_request(builder.body(data)).await
}
/// Uploads the file at `data_path` as the parcel identified by `parcel_sha` for the given
/// bindle, streaming the file contents as the request body.
#[instrument(level = "trace", skip(self, bindle_id, data_path), fields(invoice_id, path = %data_path.as_ref().display()))]
pub async fn create_parcel_from_file<D, I>(
    &self,
    bindle_id: I,
    parcel_sha: &str,
    data_path: D,
) -> Result<()>
where
    I: TryInto<Id>,
    I::Error: Into<ClientError>,
    D: AsRef<Path>,
{
    let owned_path = data_path.as_ref().to_owned();
    let parsed_id = bindle_id.try_into().map_err(|e| e.into())?;
    // The ID is only known after parsing, so it is recorded on the span here.
    tracing::span::Span::current().record("invoice_id", &tracing::field::display(&parsed_id));
    debug!("Loading parcel data from file");
    let file_stream = load::raw(owned_path).await?;
    debug!("Successfully loaded parcel stream");
    let payload = Body::wrap_stream(file_stream);
    let builder = self.create_parcel_builder(&parsed_id, parcel_sha).await?;
    self.create_parcel_request(builder.body(payload)).await
}
/// Uploads a parcel identified by `parcel_sha` for the given bindle, taking the data from a
/// stream of buffers.
///
/// Each successfully yielded buffer is copied in full into a `Bytes` chunk before being handed
/// to the request body; stream errors are passed through unchanged.
#[instrument(level = "trace", skip(self, bindle_id, stream), fields(invoice_id))]
pub async fn create_parcel_from_stream<I, S, B>(
    &self,
    bindle_id: I,
    parcel_sha: &str,
    stream: S,
) -> Result<()>
where
    I: TryInto<Id>,
    I::Error: Into<ClientError>,
    S: Stream<Item = std::io::Result<B>> + Unpin + Send + Sync + 'static,
    B: bytes::Buf,
{
    let parsed_id = bindle_id.try_into().map_err(|e| e.into())?;
    // The ID is only known after parsing, so it is recorded on the span here.
    tracing::span::Span::current().record("invoice_id", &tracing::field::display(&parsed_id));
    let byte_stream = stream.map(|chunk| {
        chunk.map(|mut buf| {
            let len = buf.remaining();
            buf.copy_to_bytes(len)
        })
    });
    let payload = Body::wrap_stream(byte_stream);
    let builder = self.create_parcel_builder(&parsed_id, parcel_sha).await?;
    self.create_parcel_request(builder.body(payload)).await
}
/// Builds an authenticated `POST` request to `<invoice endpoint>/<bindle_id>@<parcel_sha>`.
///
/// # Errors
///
/// Fails if the parcel path cannot be joined onto the base URL, or if the auth header cannot be
/// applied. The join failure is propagated rather than panicking, since the path includes the
/// caller-supplied `parcel_sha`.
async fn create_parcel_builder(
    &self,
    bindle_id: &Id,
    parcel_sha: &str,
) -> Result<RequestBuilder> {
    let path = format!("{}/{}@{}", INVOICE_ENDPOINT, bindle_id, parcel_sha);
    let url = self.base_url.join(&path)?;
    let req = self.client.post(url);
    self.token_manager.apply_auth_header(req).await
}
/// Sends a prepared parcel creation request and checks the response status, discarding the
/// response body.
async fn create_parcel_request(&self, req: RequestBuilder) -> Result<()> {
    trace!(?req);
    let raw_resp = req.send().await?;
    unwrap_status(raw_resp, Endpoint::Parcel, Operation::Create)
        .await
        .map(|_| ())
}
/// Downloads the parcel identified by `sha` for the given bindle and returns its full contents
/// as a byte vector.
#[instrument(level = "trace", skip(self, bindle_id), fields(invoice_id))]
pub async fn get_parcel<I>(&self, bindle_id: I, sha: &str) -> Result<Vec<u8>>
where
    I: TryInto<Id>,
    I::Error: Into<ClientError>,
{
    let parsed_id = bindle_id.try_into().map_err(|e| e.into())?;
    // The ID is only known after parsing, so it is recorded on the span here.
    tracing::span::Span::current().record("invoice_id", &tracing::field::display(&parsed_id));
    let resp = self.get_parcel_request(&parsed_id, sha).await?;
    let body = resp.bytes().await?;
    Ok(body.to_vec())
}
#[instrument(level = "trace", skip(self, bindle_id), fields(invoice_id))]
pub async fn get_parcel_stream<I>(
&self,
bindle_id: I,
sha: &str,
) -> Result<impl Stream<Item = Result<bytes::Bytes>>>
where
I: TryInto<Id>,
I::Error: Into<ClientError>,
{
let parsed_id = bindle_id.try_into().map_err(|e| e.into())?;
tracing::span::Span::current().record("invoice_id", &tracing::field::display(&parsed_id));
let resp = self.get_parcel_request(&parsed_id, sha).await?;
Ok(resp.bytes_stream().map(|r| r.map_err(|e| e.into())))
}
/// Performs an authenticated `GET` on `<invoice endpoint>/<bindle_id>@<sha>` and returns the
/// response once its status has been checked.
///
/// The request overrides the `Accept` header with `*/*`.
///
/// # Errors
///
/// Fails if the parcel path cannot be joined onto the base URL (propagated rather than
/// panicking, since the path includes the caller-supplied `sha`), if the auth header cannot be
/// applied, if the request fails, or if the server responds with a non-success status.
async fn get_parcel_request(&self, bindle_id: &Id, sha: &str) -> Result<reqwest::Response> {
    let path = format!("{}/{}@{}", INVOICE_ENDPOINT, bindle_id, sha);
    let url = self.base_url.join(&path)?;
    let req = self.client.get(url).header(header::ACCEPT, "*/*");
    let req = self.token_manager.apply_auth_header(req).await?;
    trace!(?req);
    let resp = req.send().await?;
    unwrap_status(resp, Endpoint::Parcel, Operation::Get).await
}
/// Asks the relationship endpoint which parcels are missing for the given bindle and returns
/// their labels.
#[instrument(level = "trace", skip(self, id), fields(invoice_id))]
pub async fn get_missing_parcels<I>(&self, id: I) -> Result<Vec<crate::Label>>
where
    I: TryInto<Id>,
    I::Error: Into<ClientError>,
{
    let parsed_id = id.try_into().map_err(|e| e.into())?;
    // The ID is only known after parsing, so it is recorded on the span here.
    tracing::span::Span::current().record("invoice_id", &tracing::field::display(&parsed_id));
    let path = format!("{}/{}/{}", RELATIONSHIP_ENDPOINT, "missing", parsed_id);
    let url = self.base_url.join(&path)?;
    let req = self
        .token_manager
        .apply_auth_header(self.client.get(url))
        .await?;
    trace!(?req);
    let raw_resp = req.send().await?;
    // Status handling reuses the invoice endpoint's mapping.
    let checked = unwrap_status(raw_resp, Endpoint::Invoice, Operation::Get).await?;
    let body = checked.bytes().await?;
    let parsed: crate::MissingParcelsResponse = toml::from_slice(&body)?;
    Ok(parsed.missing)
}
/// Fetches the keyring served at the bindle-keys endpoint and parses it from TOML.
#[instrument(level = "trace", skip(self))]
pub async fn get_host_keys(&self) -> Result<KeyRing> {
    let raw_resp = self
        .raw(reqwest::Method::GET, BINDLE_KEYS_ENDPOINT, None::<&str>)
        .await?;
    let checked = unwrap_status(raw_resp, Endpoint::BindleKeys, Operation::Get).await?;
    let body = checked.bytes().await?;
    let keyring: KeyRing = toml::from_slice(&body)?;
    Ok(keyring)
}
}
#[async_trait::async_trait]
impl<T: tokens::TokenManager + Send + Sync + 'static> Provider for Client<T> {
/// Creates the signed form of `invoice` on the server and returns the stored invoice together
/// with the labels of any parcels the server reports as missing (empty when none are reported).
async fn create_invoice<I>(
    &self,
    invoice: I,
) -> crate::provider::Result<(crate::Invoice, Vec<crate::Label>)>
where
    I: Signed + Verified + Send + Sync,
{
    let created = self.create_invoice(invoice.signed()).await?;
    let missing = created.missing.unwrap_or_default();
    Ok((created.invoice, missing))
}
/// Fetches a (possibly yanked) invoice through the client and converts the verified result
/// into a plain `Invoice`, mapping client errors into provider errors.
async fn get_yanked_invoice<I>(&self, id: I) -> crate::provider::Result<crate::Invoice>
where
    I: TryInto<Id> + Send,
    I::Error: Into<ProviderError>,
{
    let parsed_id = id.try_into().map_err(|e| e.into())?;
    match self.get_yanked_invoice(parsed_id).await {
        Ok(verified) => Ok(verified.into()),
        Err(err) => Err(err.into()),
    }
}
/// Yanks the invoice through the client, mapping client errors into provider errors.
async fn yank_invoice<I>(&self, id: I) -> crate::provider::Result<()>
where
    I: TryInto<Id> + Send,
    I::Error: Into<ProviderError>,
{
    let parsed_id = id.try_into().map_err(|e| e.into())?;
    match self.yank_invoice(parsed_id).await {
        Ok(()) => Ok(()),
        Err(err) => Err(err.into()),
    }
}
/// Uploads a parcel from a stream through the client, mapping client errors into provider
/// errors.
async fn create_parcel<I, R, B>(
    &self,
    bindle_id: I,
    parcel_id: &str,
    data: R,
) -> crate::provider::Result<()>
where
    I: TryInto<Id> + Send,
    I::Error: Into<ProviderError>,
    R: Stream<Item = std::io::Result<B>> + Unpin + Send + Sync + 'static,
    B: bytes::Buf,
{
    let parsed_id = bindle_id.try_into().map_err(|e| e.into())?;
    let outcome = self
        .create_parcel_from_stream(parsed_id, parcel_id, data)
        .await;
    match outcome {
        Ok(()) => Ok(()),
        Err(err) => Err(err.into()),
    }
}
async fn get_parcel<I>(
&self,
bindle_id: I,
parcel_id: &str,
) -> crate::provider::Result<
Box<dyn Stream<Item = crate::provider::Result<bytes::Bytes>> + Unpin + Send + Sync>,
>
where
I: TryInto<Id> + Send,
I::Error: Into<ProviderError>,
{
let parsed_id = bindle_id.try_into().map_err(|e| e.into())?;
let stream = self.get_parcel_stream(parsed_id, parcel_id).await?;
Ok(Box::new(stream.map(|res| res.map_err(|e| e.into()))))
}
/// Checks whether a parcel exists by sending a `HEAD` request for it.
///
/// Returns `Ok(true)` on 200, `Ok(false)` on 404, and a `ProxyError` wrapping an
/// `InvalidRequest` (with no message) for any other status. Failures while sending the request
/// are reported as `ProviderError::Other` carrying the stringified cause.
async fn parcel_exists<I>(&self, bindle_id: I, parcel_id: &str) -> crate::provider::Result<bool>
where
    I: TryInto<Id> + Send,
    I::Error: Into<ProviderError>,
{
    let parsed_id = bindle_id.try_into().map_err(|e| e.into())?;
    let path = format!(
        "{}/{}@{}",
        crate::client::INVOICE_ENDPOINT,
        parsed_id,
        parcel_id,
    );
    let resp = self
        .raw(reqwest::Method::HEAD, &path, None::<reqwest::Body>)
        .await
        .map_err(|e| ProviderError::Other(e.to_string()))?;
    let status = resp.status();
    if status == StatusCode::OK {
        Ok(true)
    } else if status == StatusCode::NOT_FOUND {
        Ok(false)
    } else {
        Err(ProviderError::ProxyError(ClientError::InvalidRequest {
            status_code: status,
            message: None,
        }))
    }
}
}
// The kind of operation a request performed; `unwrap_status` uses it to pick a more specific
// error and prints it (via `Debug`) in its catch-all error message.
#[derive(Debug)]
enum Operation {
Create,
Yank,
Get,
Query,
// NOTE(review): not constructed in the visible part of this module; presumably used by the
// token managers -- confirm.
Login,
}
// The endpoint a request targeted; `unwrap_status` uses it to decide how a status code maps to
// success or to a specific `ClientError`.
enum Endpoint {
Invoice,
Parcel,
Query,
// NOTE(review): not constructed in the visible part of this module; presumably used by the
// token managers -- confirm.
Login,
BindleKeys,
}
/// Maps an HTTP response onto the client's result type based on its status, the endpoint that
/// was called, and the operation that was attempted.
///
/// Accepted statuses hand the response back untouched; everything else becomes a
/// [`ClientError`]. Arm order matters: specific status/endpoint pairs come first, followed by
/// the generic 5xx and 4xx guards, and finally a catch-all.
async fn unwrap_status(
    resp: reqwest::Response,
    endpoint: Endpoint,
    operation: Operation,
) -> Result<reqwest::Response> {
    // `StatusCode` is read once up front so `resp` can be consumed freely inside the arms.
    let status = resp.status();
    match (status, endpoint) {
        // 200 is accepted everywhere; 202 and 201 only from the invoice endpoint.
        (StatusCode::OK, _)
        | (StatusCode::ACCEPTED, Endpoint::Invoice)
        | (StatusCode::CREATED, Endpoint::Invoice) => Ok(resp),
        (StatusCode::NOT_FOUND, Endpoint::Invoice) | (StatusCode::FORBIDDEN, Endpoint::Invoice) => {
            Err(if matches!(operation, Operation::Get) {
                ClientError::InvoiceNotFound
            } else {
                ClientError::ResourceNotFound
            })
        }
        (StatusCode::NOT_FOUND, Endpoint::Parcel) => {
            Err(if matches!(operation, Operation::Get) {
                ClientError::ParcelNotFound
            } else {
                ClientError::ResourceNotFound
            })
        }
        (StatusCode::CONFLICT, Endpoint::Invoice) => Err(ClientError::InvoiceAlreadyExists),
        (StatusCode::CONFLICT, Endpoint::Parcel) => Err(ClientError::ParcelAlreadyExists),
        (StatusCode::UNAUTHORIZED, _) => Err(ClientError::Unauthorized),
        (StatusCode::BAD_REQUEST, Endpoint::BindleKeys) => {
            let message = parse_error_from_body(resp).await;
            Err(ClientError::InvalidRequest {
                status_code: status,
                message,
            })
        }
        (StatusCode::BAD_REQUEST, _) => {
            let detail = parse_error_from_body(resp).await.unwrap_or_default();
            Err(ClientError::ServerError(Some(format!(
                "Bad request: {}",
                detail
            ))))
        }
        // Status codes cannot be range-matched, so the classes are handled with guards.
        _ if status.is_server_error() => {
            Err(ClientError::ServerError(parse_error_from_body(resp).await))
        }
        _ if status.is_client_error() => {
            let message = parse_error_from_body(resp).await;
            Err(ClientError::InvalidRequest {
                status_code: status,
                message,
            })
        }
        _ => {
            // Capture the URL before the response is consumed to read its body.
            let url = resp.url().to_owned();
            let detail = parse_error_from_body(resp)
                .await
                .unwrap_or_else(|| "(no error message in response)".to_owned());
            Err(ClientError::Other(format!(
                "Unknown error response: {:?} to {} returned status {}: {}",
                operation, url, status, detail
            )))
        }
    }
}
/// Reads the response body and tries to parse it as a TOML `ErrorResponse`, returning its
/// `error` field.
///
/// Returns `None` if the body cannot be read or does not parse.
async fn parse_error_from_body(resp: reqwest::Response) -> Option<String> {
    let body = resp.bytes().await.ok()?;
    toml::from_slice::<crate::ErrorResponse>(&body)
        .ok()
        .map(|parsed| parsed.error)
}
/// Extension trait for builder-style types that lets a chained call apply a step only when a
/// condition holds.
trait ConditionalBuilder {
    /// Returns `build_method(self)` when `condition` is true and `self` unchanged otherwise.
    ///
    /// The closure is invoked at most once, so it only needs to be `FnOnce`; this also admits
    /// closures that consume values they capture.
    fn and_if(self, condition: bool, build_method: impl FnOnce(Self) -> Self) -> Self
    where
        Self: Sized,
    {
        if condition {
            build_method(self)
        } else {
            self
        }
    }
}
// Makes `and_if` available on reqwest's client builder; the default method body is used as-is.
impl ConditionalBuilder for reqwest::ClientBuilder {}