use std::{
collections::BTreeMap,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::{Service, ServiceExt};
use crate::error::TransportError;
/// Transport-agnostic description of an HTTP request handed to a [`Transport`].
///
/// Construct it directly, with [`SdkRequest::new`], or with [`SdkRequest::builder`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SdkRequest {
/// Raw request body; `None` means no body is attached.
pub body: Option<Vec<u8>>,
/// Whether the request needs authentication. `HpxTransport::call` does not read this
/// flag itself; presumably it is consumed by an auth layer in front of the transport — TODO confirm.
pub requires_auth: bool,
/// Request headers, one value per name. The map compares names case-sensitively.
pub headers: BTreeMap<String, String>,
/// HTTP method token (e.g. `"GET"`). `HpxTransport` parses it with `hpx::Method::from_bytes`
/// and fails with `TransportError::InvalidMethod` if that parse fails.
pub method: String,
/// Request target passed as-is to `hpx::Client::request` by `HpxTransport`.
/// NOTE(review): presumably this must be an absolute URL for that client — confirm.
pub path: String,
}
impl SdkRequest {
    /// Starts a [`SdkRequestBuilder`] with the method and path already filled in.
    ///
    /// The builder begins with no body, `requires_auth == false` and an empty header map.
    #[must_use]
    pub fn builder(
        method: impl Into<String>,
        path: impl Into<String>,
    ) -> SdkRequestBuilder<MethodSet, PathSet> {
        // Convert in the same order as the field list: method first, then path.
        let method = method.into();
        let path = path.into();
        SdkRequestBuilder {
            method,
            path,
            body: None,
            requires_auth: false,
            headers: BTreeMap::default(),
            _state: std::marker::PhantomData,
        }
    }

    /// Creates a bare request with no body, no headers and `requires_auth == false`.
    ///
    /// Shorthand for `SdkRequest::builder(method, path).build()`.
    #[must_use]
    pub fn new(method: impl Into<String>, path: impl Into<String>) -> Self {
        Self::builder(method, path).build()
    }
}
/// Typestate marker for [`SdkRequestBuilder`]: the HTTP method has been supplied.
#[derive(Debug, Clone, Copy)]
pub struct MethodSet;
/// Typestate marker for [`SdkRequestBuilder`]: the request path has been supplied.
#[derive(Debug, Clone, Copy)]
pub struct PathSet;
/// Fluent builder for [`SdkRequest`].
///
/// `M` and `P` are typestate markers. `build` is only implemented for
/// `SdkRequestBuilder<MethodSet, PathSet>`, which is the state `SdkRequest::builder` returns.
#[derive(Debug)]
pub struct SdkRequestBuilder<M, P> {
method: String,
path: String,
body: Option<Vec<u8>>,
requires_auth: bool,
headers: BTreeMap<String, String>,
// Zero-sized; carries the `M`/`P` typestate only.
_state: std::marker::PhantomData<(M, P)>,
}
impl SdkRequestBuilder<MethodSet, PathSet> {
    /// Consumes the builder and returns the finished [`SdkRequest`].
    ///
    /// Only callable once both the method (`MethodSet`) and path (`PathSet`) are present.
    #[must_use]
    pub fn build(self) -> SdkRequest {
        let Self {
            method,
            path,
            body,
            requires_auth,
            headers,
            _state: _,
        } = self;
        SdkRequest {
            body,
            requires_auth,
            headers,
            method,
            path,
        }
    }
}
impl<M, P> SdkRequestBuilder<M, P> {
    /// Sets the request body, replacing any body set earlier.
    #[must_use]
    pub fn body(self, body: impl Into<Vec<u8>>) -> Self {
        Self {
            body: Some(body.into()),
            ..self
        }
    }

    /// Sets whether the request requires authentication.
    #[must_use]
    pub const fn auth_required(mut self, required: bool) -> Self {
        // Assign-then-return (rather than struct-update) keeps this a valid `const fn`.
        self.requires_auth = required;
        self
    }

    /// Adds one header. An existing header with exactly the same name is overwritten.
    #[must_use]
    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let key: String = name.into();
        let val: String = value.into();
        self.headers.insert(key, val);
        self
    }

    /// Merges `headers` into the builder. Entries in `headers` overwrite existing ones with
    /// the same name; other existing headers are kept.
    #[must_use]
    pub fn headers(mut self, headers: BTreeMap<String, String>) -> Self {
        for (name, value) in headers {
            self.headers.insert(name, value);
        }
        self
    }
}
/// Fully buffered HTTP response returned by a [`Transport`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SdkResponse {
/// Entire response body as raw bytes.
pub body: Vec<u8>,
/// Response headers keyed by name, one string value per name.
pub headers: BTreeMap<String, String>,
/// Numeric HTTP status code (e.g. `200`).
pub status: u16,
}
/// Alias-style trait for any `Send + Sync` tower [`Service`] that turns an [`SdkRequest`]
/// into an [`SdkResponse`] and fails with [`TransportError`].
///
/// It has no methods of its own. A blanket impl makes every qualifying service a `Transport`.
pub trait Transport:
Service<SdkRequest, Response = SdkResponse, Error = TransportError> + Send + Sync
{
}
impl<T> Transport for T where
T: Service<SdkRequest, Response = SdkResponse, Error = TransportError> + Send + Sync
{
}
/// Convenience methods layered on top of [`Transport`].
pub trait TransportExt: Transport {
/// Sends `request` and resolves to its response or a [`TransportError`].
///
/// The returned future is required to be `Send`. The blanket impl clones the transport
/// and drives the clone with `tower::ServiceExt::oneshot`, so `self` itself is never polled.
fn execute(
&mut self,
request: SdkRequest,
) -> impl Future<Output = Result<SdkResponse, TransportError>> + Send;
}
impl<T> TransportExt for T
where
T: Transport + Clone,
<T as Service<SdkRequest>>::Future: Send,
{
async fn execute(&mut self, request: SdkRequest) -> Result<SdkResponse, TransportError> {
let transport = self.clone();
transport.oneshot(request).await
}
}
/// [`Transport`] implementation backed by an `hpx::Client`.
///
/// `Clone` clones the wrapped client, and `call` also clones it once per request.
#[derive(Clone)]
pub struct HpxTransport {
client: hpx::Client,
}
impl std::fmt::Debug for HpxTransport {
    /// Renders as `HpxTransport { .. }`; the wrapped client is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("HpxTransport");
        out.finish_non_exhaustive()
    }
}
impl Default for HpxTransport {
    /// Wraps a client built with `hpx::Client::new()`.
    fn default() -> Self {
        Self {
            client: hpx::Client::new(),
        }
    }
}
impl HpxTransport {
    /// Wraps an already-configured `hpx::Client`.
    #[must_use]
    pub const fn new(inner: hpx::Client) -> Self {
        Self { client: inner }
    }
}
impl Service<SdkRequest> for HpxTransport {
    type Response = SdkResponse;
    type Error = TransportError;
    type Future = Pin<Box<dyn Future<Output = Result<SdkResponse, TransportError>> + Send>>;

    /// Always reports ready; this layer applies no backpressure of its own.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Sends `request` through a clone of the wrapped client and buffers the whole response.
    ///
    /// Response headers whose `to_str()` conversion fails are skipped. A header that appears
    /// on several field lines is merged into one entry (see `merge_response_header`) instead
    /// of keeping only the last line.
    ///
    /// # Errors
    ///
    /// The returned future resolves to `Err` with:
    /// - `TransportError::InvalidMethod` when `request.method` is rejected by
    ///   `hpx::Method::from_bytes`;
    /// - the `TransportError` converted (via `?`) from a failure in `send()` or while
    ///   reading the response body.
    fn call(&mut self, request: SdkRequest) -> Self::Future {
        // Clone so the boxed future is `'static` and does not borrow `self`.
        let client = self.client.clone();
        Box::pin(async move {
            let method = hpx::Method::from_bytes(request.method.as_bytes())
                .map_err(|_| TransportError::InvalidMethod { method: request.method.clone() })?;
            let mut builder = client.request(method, request.path);
            for (name, value) in request.headers {
                builder = builder.header(name, value);
            }
            if let Some(body) = request.body {
                builder = builder.body(body);
            }
            let response = builder.send().await?;
            let status = response.status().as_u16();
            // Collect headers before `bytes()` consumes `response`.
            let mut headers = BTreeMap::new();
            for (name, value) in response.headers().iter() {
                if let Ok(value) = value.to_str() {
                    merge_response_header(&mut headers, name.to_string(), value);
                }
            }
            let body = response.bytes().await?.to_vec();
            Ok(SdkResponse { body, headers, status })
        })
    }
}

/// Folds one response header field line into `headers` without losing earlier lines.
///
/// Repeated field lines are joined with `", "`, the combination RFC 9110 §5.3 allows.
/// `Set-Cookie` is the documented exception because its values may contain commas
/// (e.g. in `Expires`). Its lines are joined with `'\n'` instead, which cannot occur
/// inside a valid header value, so callers can split them back apart unambiguously.
fn merge_response_header(headers: &mut BTreeMap<String, String>, name: String, value: &str) {
    let separator = if name.eq_ignore_ascii_case("set-cookie") {
        "\n"
    } else {
        ", "
    };
    headers
        .entry(name)
        .and_modify(|existing| {
            existing.push_str(separator);
            existing.push_str(value);
        })
        .or_insert_with(|| value.to_owned());
}