use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
mod accept;
pub mod body;
pub mod compat;
mod error;
pub mod request;
mod response;
pub mod security;
mod server;
pub use accept::WithAccept;
pub use error::{BoxError, EncodeRequestError};
pub use request::Request;
pub use response::Response;
pub use security::{AuthSelector, NoAuth, OperationSecurity, SecurityCredential};
pub use server::{Server, WithServer};
pub trait MakeRequest {
type Error: std::error::Error + Send + Sync + 'static;
fn make_request(self) -> impl Future<Output = Result<Request, Self::Error>> + Send;
}
pub trait ParseResponse: Sized {
type Error: std::error::Error;
fn parse_response<B>(
response: ::http::Response<B>,
) -> impl Future<Output = Result<Self, Self::Error>> + Send
where
B: http_body::Body<Data = ::bytes::Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError>;
}
#[derive(Debug, thiserror::Error)]
pub enum DecodeError {
#[error("unexpected HTTP status: {0}")]
UnexpectedStatus(http::StatusCode),
#[error("codec error: {0}")]
Codec(#[source] BoxError),
}
/// An API operation: it can build its own request (via [`MakeRequest`]) and
/// names the type its response decodes into.
pub trait Operation
where
    Self: MakeRequest,
{
    /// Decoded response type. [`ApiClient`]'s `tower::Service` impl additionally
    /// requires this to implement [`ParseResponse`].
    type Response;
}
/// Error returned by [`ApiClient`]'s `tower::Service` impl, tagged by the stage
/// of the call that failed.
///
/// `TransportError` is the error type of the wrapped transport service.
#[derive(Debug, thiserror::Error)]
pub enum CallError<TransportError> {
    /// `MakeRequest::make_request` failed.
    #[error("encode error: {0}")]
    Encode(#[source] BoxError),
    /// The configured `AuthSelector` could not apply credentials.
    #[error("auth error: {0}")]
    Auth(#[source] BoxError),
    /// The inner transport failed in `poll_ready` or `call`.
    #[error("transport error: {0}")]
    Transport(#[source] TransportError),
    /// `ParseResponse::parse_response` failed.
    #[error("decode error: {0}")]
    Decode(#[source] DecodeError),
}
/// Typed API client wrapping a transport service `S`.
///
/// Every request has the configured base URL prefixed and authentication applied
/// before it is handed to `S`. Cloning is cheap apart from cloning `S`, because
/// the base URL and auth selector are reference-counted.
#[derive(Clone)]
pub struct ApiClient<S> {
    /// Transport that actually sends `Request`s.
    inner: S,
    /// Base URL prefixed onto every outgoing request.
    base_url: Arc<str>,
    /// Strategy used to attach credentials to a request.
    auth: Arc<dyn AuthSelector>,
}
impl<S: std::fmt::Debug> std::fmt::Debug for ApiClient<S> {
    /// Formats the transport and base URL; the auth selector is shown only as a
    /// fixed placeholder because it is held as a trait object.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `ApiClient` makes this fail
        // to compile until the field is accounted for here.
        let Self {
            inner,
            base_url,
            auth: _,
        } = self;
        let mut out = f.debug_struct("ApiClient");
        out.field("inner", inner);
        out.field("base_url", base_url);
        out.field("auth", &"<dyn AuthSelector>");
        out.finish()
    }
}
impl<S> ApiClient<S> {
    /// Wraps the transport `inner` so that requests are resolved against
    /// `server`'s base URL.
    ///
    /// Authentication starts as `security::default_auth()`; replace it with
    /// [`ApiClient::with_auth`].
    pub fn new<Srv: Server>(inner: S, server: Srv) -> Self {
        let base_url: Arc<str> = Arc::from(server.base_url().as_ref());
        let auth = security::default_auth();
        Self {
            inner,
            base_url,
            auth,
        }
    }

    /// Returns the client with its auth selector replaced by `auth`.
    #[must_use]
    pub fn with_auth<A: AuthSelector>(self, auth: A) -> Self {
        Self {
            auth: Arc::new(auth),
            ..self
        }
    }

    /// Base URL that is prefixed onto every request.
    pub fn base_url(&self) -> &str {
        &*self.base_url
    }

    /// Consumes the client and returns the wrapped transport.
    pub fn into_inner(self) -> S {
        let Self { inner, .. } = self;
        inner
    }

    /// Mutable access to the wrapped transport.
    pub fn inner_mut(&mut self) -> &mut S {
        let Self { inner, .. } = self;
        inner
    }
}
impl<S, Op, B> tower::Service<Op> for ApiClient<S>
where
Op: Operation + Send + 'static,
Op::Error: Into<BoxError> + Send + 'static,
Op::Response: ParseResponse + Send + 'static,
<Op::Response as ParseResponse>::Error: Into<DecodeError> + Send + 'static,
S: tower::Service<Request, Response = ::http::Response<B>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Send + 'static,
B: http_body::Body<Data = ::bytes::Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError>,
{
type Response = Op::Response;
type Error = CallError<S::Error>;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(CallError::Transport)
}
fn call(&mut self, op: Op) -> Self::Future {
let inner = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, inner);
let base_url = self.base_url.clone();
let auth = self.auth.clone();
Box::pin(async move {
let http_req = op
.make_request()
.await
.map_err(|e| CallError::Encode(e.into()))?;
let http_req = server::prefix_base_url(http_req, &base_url);
let requirements = http_req
.extensions()
.get::<OperationSecurity>()
.copied()
.unwrap_or(OperationSecurity::PUBLIC);
let http_req = auth
.apply_for(http_req, requirements.0)
.await
.map_err(CallError::Auth)?;
tracing::info!(uri= ?http_req.uri(), method = ?http_req.method(), headers = ?http_req.headers(), body = ?http_req.body(), "request");
let http_resp = inner.call(http_req).await.map_err(CallError::Transport)?;
Op::Response::parse_response(http_resp)
.await
.map_err(|e| CallError::Decode(e.into()))
})
}
}
#[cfg(feature = "base64")]
/// Binary payload whose textual form is standard (padded) base64.
///
/// The raw bytes are stored as `bytes::Bytes`. Encoding happens on demand when
/// the value is displayed, debug-printed, or serialized; decoding happens on
/// deserialization.
#[derive(Clone, PartialEq, Eq, Hash, Default)]
pub struct Base64String(::bytes::Bytes);
#[cfg(feature = "base64")]
impl Base64String {
    /// Wraps raw bytes. No copy is made when `bytes` is already a `Bytes`.
    pub fn from_bytes(bytes: impl Into<::bytes::Bytes>) -> Self {
        let raw: ::bytes::Bytes = bytes.into();
        Self(raw)
    }

    /// Borrows the raw (decoded) bytes.
    pub fn as_bytes(&self) -> &[u8] {
        &self.0
    }

    /// Consumes the wrapper and returns the raw bytes.
    pub fn into_bytes(self) -> ::bytes::Bytes {
        let Self(raw) = self;
        raw
    }

    /// Parses standard, padded base64 text into raw bytes.
    ///
    /// # Errors
    ///
    /// Returns the `base64` crate's `DecodeError` when `encoded` is not valid
    /// standard base64.
    pub fn decode(encoded: &str) -> Result<Self, ::base64::DecodeError> {
        use ::base64::Engine as _;
        ::base64::engine::general_purpose::STANDARD
            .decode(encoded)
            .map(|buf| Self(::bytes::Bytes::from(buf)))
    }

    /// Renders the raw bytes as standard, padded base64 text.
    pub fn encode(&self) -> String {
        use ::base64::Engine as _;
        ::base64::engine::general_purpose::STANDARD.encode(self.as_bytes())
    }
}
#[cfg(feature = "base64")]
impl std::fmt::Display for Base64String {
    /// Writes the base64 encoding, ignoring width/fill flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let encoded = self.encode();
        f.write_str(encoded.as_str())
    }
}
#[cfg(feature = "base64")]
impl std::fmt::Debug for Base64String {
    /// Shows the encoded form, e.g. `Base64String("aGVsbG8=")`, instead of a raw
    /// byte dump.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let encoded = self.encode();
        let mut tuple = f.debug_tuple("Base64String");
        tuple.field(&encoded);
        tuple.finish()
    }
}
#[cfg(feature = "base64")]
impl From<::bytes::Bytes> for Base64String {
    /// Wraps the buffer as-is; no copy is made.
    fn from(value: ::bytes::Bytes) -> Self {
        Base64String(value)
    }
}
#[cfg(feature = "base64")]
impl From<Vec<u8>> for Base64String {
    /// Takes ownership of the vector's allocation via `Bytes::from`.
    fn from(value: Vec<u8>) -> Self {
        let raw = ::bytes::Bytes::from(value);
        Base64String(raw)
    }
}
#[cfg(feature = "base64")]
impl AsRef<[u8]> for Base64String {
    /// Borrows the raw (decoded) bytes.
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}
#[cfg(feature = "base64")]
impl serde::Serialize for Base64String {
    /// Serializes as a single string holding the standard base64 encoding.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let text = self.encode();
        serializer.serialize_str(&text)
    }
}
#[cfg(feature = "base64")]
impl<'de> serde::Deserialize<'de> for Base64String {
    /// Expects a string containing standard, padded base64. Invalid input is
    /// reported through the deserializer's custom-error channel.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // serde's `Cow<'a, str>` impl always produces `Cow::Owned`, so
        // deserializing straight into `String` does exactly the same work.
        let text = <String as serde::Deserialize>::deserialize(deserializer)?;
        Self::decode(&text).map_err(<D::Error as serde::de::Error>::custom)
    }
}
/// Includes a generated client source file from the calling crate's `OUT_DIR`.
///
/// `$stem` is the file name without its `.rs` extension, so
/// `include_client!("petstore")` expands to an `include!` of
/// `concat!(env!("OUT_DIR"), "/", "petstore", ".rs")`. A trailing comma is
/// accepted.
///
/// The built-in macros are called through `::core` so that macros named
/// `include`, `concat`, or `env` in the caller's scope cannot shadow them at
/// the expansion site.
#[macro_export]
macro_rules! include_client {
    ($stem:literal $(,)?) => {
        ::core::include!(::core::concat!(::core::env!("OUT_DIR"), "/", $stem, ".rs"));
    };
}
#[cfg(all(test, feature = "base64"))]
mod base64_tests {
    //! Behavioral tests for `Base64String`: encoding, formatting, conversions,
    //! and serde round-trips, including the empty and invalid edge cases.
    use super::Base64String;

    #[test]
    fn json_round_trip() {
        let original = Base64String::from_bytes(b"hello".to_vec());
        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, "\"aGVsbG8=\"");
        let decoded: Base64String = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.as_bytes(), b"hello");
        assert_eq!(decoded, original);
    }

    #[test]
    fn display_emits_base64() {
        let v = Base64String::from_bytes(vec![0x00, 0xff, 0x10]);
        assert_eq!(v.to_string(), "AP8Q");
    }

    #[test]
    fn debug_shows_encoded_form() {
        let v = Base64String::from_bytes(vec![0x00, 0xff, 0x10]);
        assert_eq!(format!("{v:?}"), "Base64String(\"AP8Q\")");
    }

    #[test]
    fn empty_value_round_trips() {
        let empty = Base64String::default();
        assert_eq!(empty.to_string(), "");
        let json = serde_json::to_string(&empty).unwrap();
        assert_eq!(json, "\"\"");
        let decoded: Base64String = serde_json::from_str(&json).unwrap();
        assert!(decoded.as_bytes().is_empty());
    }

    #[test]
    fn conversions_agree() {
        let from_vec = Base64String::from(vec![b'a', b'b', b'c']);
        let from_bytes = Base64String::from(::bytes::Bytes::from_static(b"abc"));
        assert_eq!(from_vec, from_bytes);
        assert_eq!(from_vec.encode(), "YWJj");
        assert_eq!(from_bytes.as_ref(), b"abc");
        assert_eq!(from_vec.into_bytes(), ::bytes::Bytes::from_static(b"abc"));
    }

    #[test]
    fn decode_rejects_invalid_input() {
        assert!(Base64String::decode("not base64!").is_err());
    }

    #[test]
    fn invalid_base64_errors_on_deserialize() {
        let err = serde_json::from_str::<Base64String>("\"not base64!\"").unwrap_err();
        assert!(
            err.to_string().contains("Invalid")
                || err.to_string().to_lowercase().contains("base64")
        );
    }
}