use http_body_util::Full;
use hyper::body::Body;
use hyper::{body::Bytes, Request, Response};
use std::future::Future;
/// A transport that sends a single HTTP request and resolves to the server's response.
///
/// Every layer of the built-in client implements this trait (the HTTP/2 connection, the
/// cookie-tracking wrapper, the retry wrapper and `DefaultTransport`), so the layers can be
/// stacked. The trait itself is public and not gated on the `default-client` feature, so it can
/// also be implemented outside this file.
pub trait MakeRequest: Sized {
/// Body type of the responses produced by this transport.
type Body: Body;
/// Error returned when a request could not be completed.
type Error: std::error::Error + Send + Sync + 'static;
/// Sends `request` (with a fully buffered body) and resolves to the response or an error.
///
/// NOTE(review): the returned future carries no `Send` bound, so generic callers cannot
/// assume it may be spawned on a multi-threaded executor.
fn request(
&self,
request: Request<Full<Bytes>>,
) -> impl Future<Output = std::result::Result<Response<Self::Body>, Self::Error>>;
}
#[cfg(feature = "default-client")]
mod default_impl {
use super::MakeRequest;
use cookie::time::OffsetDateTime;
use cookie::{Cookie, CookieJar};
use http_body_util::Full;
use hyper::client::conn::http2::Builder as Http2Builder;
use hyper::{
body::{Bytes, Incoming},
client::conn::http2::SendRequest,
Request, Response, StatusCode,
};
use hyper_util::rt::{TokioExecutor, TokioIo};
use rustls::ClientConfig;
use std::{
error::Error,
fmt::{Debug, Display},
sync::Arc,
time::Duration,
};
use tokio::{
net::TcpStream,
sync::{Mutex, RwLock},
};
use tokio_rustls::TlsConnector;
struct Http2Only {
force_ipv4: bool,
config: Arc<ClientConfig>,
send: tokio::sync::Mutex<Option<SendRequest<Full<Bytes>>>>,
}
impl Http2Only {
async fn make_connection(&self) -> Result<SendRequest<Full<Bytes>>, DefaultTransportError> {
let host = if self.force_ipv4 {
"api-ipv4.porkbun.com"
} else {
"api.porkbun.com"
};
#[cfg(feature = "tracing")]
tracing::debug!(target: "porkbun_api::transport", "connecting to {}", host);
let arc_config = self.config.clone();
let server_name = host.try_into().unwrap();
let tokio_tls_connecto = TlsConnector::from(arc_config);
let tcp = TcpStream::connect(if self.force_ipv4 {
"api-ipv4.porkbun.com:443"
} else {
"api.porkbun.com:443"
})
.await
.map_err(DefaultTransportErrorImpl::ConnectionError)?;
let connection = tokio_tls_connecto
.connect(server_name, tcp)
.await
.map_err(DefaultTransportErrorImpl::ConnectionError)?;
let hyper_io = TokioIo::new(connection);
let (send, conn) = Http2Builder::new(TokioExecutor::new())
.handshake(hyper_io)
.await?;
#[cfg(feature = "tracing")]
tracing::debug!(target: "porkbun_api::transport", "connection established");
tokio::spawn(conn);
Ok(send)
}
pub fn new(force_ipv4: bool) -> Self {
use rustls_platform_verifier::BuilderVerifierExt;
let mut config = rustls::ClientConfig::builder()
.with_platform_verifier()
.expect("Failed to create platform verifier")
.with_no_client_auth();
config.alpn_protocols = vec![b"h2".into()];
let config = Arc::new(config);
Self {
force_ipv4,
config,
send: Mutex::new(None),
}
}
}
impl Default for Http2Only {
fn default() -> Self {
Self::new(false)
}
}
impl MakeRequest for Http2Only {
    type Body = Incoming;
    type Error = DefaultTransportError;
    /// Sends `request` over the shared HTTP/2 connection, reconnecting first if the cached
    /// connection is missing or closed.
    ///
    /// The mutex is held only while obtaining (or re-creating) the connection handle. The handle
    /// is then cloned and the lock released, so concurrent requests are multiplexed over the
    /// single HTTP/2 connection instead of being serialized behind the lock.
    async fn request(
        &self,
        request: Request<Full<Bytes>>,
    ) -> Result<Response<Self::Body>, Self::Error> {
        let mut sender = {
            let mut slot = self.send.lock().await;
            let reusable = slot.as_ref().filter(|s| !s.is_closed()).cloned();
            match reusable {
                Some(existing) => existing,
                None => {
                    #[cfg(feature = "tracing")]
                    tracing::debug!(target: "porkbun_api::transport", "connection closed or not established, reconnecting");
                    // Reconnect while still holding the lock so concurrent callers do not each
                    // open their own connection.
                    let fresh = self.make_connection().await?;
                    *slot = Some(fresh.clone());
                    fresh
                }
            }
        };
        sender.ready().await?;
        sender
            .send_request(request)
            .await
            .map_err(DefaultTransportError::from)
    }
}
#[derive(Clone)]
struct Retry502<T: MakeRequest> {
inner: T,
}
impl<T: MakeRequest> Retry502<T> {
fn wrapping(inner: T) -> Self {
Self { inner }
}
}
impl<E, T: MakeRequest<Error = E>> MakeRequest for Retry502<T>
where
DefaultTransportError: From<E>,
{
type Body = T::Body;
type Error = DefaultTransportError;
async fn request(
&self,
request: Request<Full<Bytes>>,
) -> Result<Response<Self::Body>, Self::Error> {
let sleep_time = Duration::from_millis(250);
let max_sleep = 10;
let mut slept = 0;
let resp = loop {
let resp = self.inner.request(request.clone()).await?;
if resp.status() != StatusCode::SERVICE_UNAVAILABLE {
break resp;
} else if slept >= max_sleep {
#[cfg(feature = "tracing")]
tracing::warn!(target: "porkbun_api::transport", "retry limit reached after {} attempts", max_sleep);
return Err(DefaultTransportError(DefaultTransportErrorImpl::RetryError));
} else {
slept += 1;
#[cfg(feature = "tracing")]
tracing::info!(target: "porkbun_api::transport", "received 502, retrying (attempt {}/{}).", slept, max_sleep);
tokio::time::sleep(sleep_time).await
}
};
Ok(resp)
}
}
/// Transport wrapper that keeps a cookie jar.
///
/// Cookies set by responses (`Set-Cookie`) are stored. Matching, unexpired cookies are sent
/// back in a `Cookie` header on later requests.
pub struct TrackCookies<T> {
    inner: T,
    cookie_jar: RwLock<CookieJar>,
}
impl<T> TrackCookies<T> {
    /// Wraps `inner`, starting with an empty cookie jar.
    pub fn wrapping(inner: T) -> Self {
        Self {
            inner,
            cookie_jar: RwLock::new(CookieJar::new()),
        }
    }

    /// Returns whether `cookie` should be attached to `request`.
    ///
    /// It must pass the RFC 6265 domain-match and path-match rules, when those attributes are
    /// set, and must not have expired.
    fn is_cookie_valid_for_request(cookie: &Cookie, request: &Request<Full<Bytes>>) -> bool {
        let uri = request.uri();
        if let Some(domain) = cookie.domain() {
            if !Self::domain_matches(uri.host().unwrap_or(""), domain) {
                return false;
            }
        }
        if let Some(path) = cookie.path() {
            if !Self::path_matches(uri.path(), path) {
                return false;
            }
        }
        if let Some(expires) = cookie.expires_datetime() {
            if expires <= OffsetDateTime::now_utc() {
                return false;
            }
        }
        true
    }

    /// RFC 6265 §5.1.3 domain-match. The host is the domain itself, or a subdomain of it
    /// separated by a `.` (so `evilporkbun.com` does not match `porkbun.com`).
    /// The comparison is ASCII case-insensitive.
    fn domain_matches(host: &str, domain: &str) -> bool {
        // A leading '.' is ignored per RFC 6265 §5.2.3; strip it in case it is still present.
        let domain = domain.strip_prefix('.').unwrap_or(domain);
        // An empty Domain attribute is ignored per RFC 6265 §5.2.3 (matches any host).
        if domain.is_empty() {
            return true;
        }
        let (host, domain) = (host.as_bytes(), domain.as_bytes());
        if host.eq_ignore_ascii_case(domain) {
            return true;
        }
        match host.len().checked_sub(domain.len() + 1) {
            Some(dot) => host[dot] == b'.' && host[dot + 1..].eq_ignore_ascii_case(domain),
            None => false,
        }
    }

    /// RFC 6265 §5.1.4 path-match. The request path equals the cookie path, or continues it
    /// at a `/` boundary (so cookie path `/foo` matches `/foo/bar` but not `/foobar`).
    fn path_matches(request_path: &str, cookie_path: &str) -> bool {
        match request_path.strip_prefix(cookie_path) {
            Some(rest) => rest.is_empty() || cookie_path.ends_with('/') || rest.starts_with('/'),
            None => false,
        }
    }
}
impl<T: MakeRequest> MakeRequest for TrackCookies<T> {
type Body = T::Body;
type Error = T::Error;
async fn request(
&self,
mut request: Request<Full<Bytes>>,
) -> Result<Response<T::Body>, T::Error> {
let cookie_header = {
let jar = self.cookie_jar.read().await;
jar.iter()
.filter(|cookie| Self::is_cookie_valid_for_request(cookie, &request))
.map(|c| {
let (name, value) = c.name_value_trimmed();
format!("{name}={value}")
})
.collect::<Vec<_>>()
.join("; ")
};
if !cookie_header.is_empty() {
#[cfg(feature = "tracing")]
tracing::trace!(target: "porkbun_api::transport", "added {} cookies to request", cookie_header.split("; ").count());
request
.headers_mut()
.insert(hyper::header::COOKIE, cookie_header.parse().unwrap());
}
let response = self.inner.request(request).await?;
let cookies = response
.headers()
.get_all(hyper::header::SET_COOKIE)
.iter()
.filter_map(|h| h.to_str().ok())
.filter_map(|s| Cookie::parse(s).ok())
.collect::<Vec<_>>();
if !cookies.is_empty() {
#[cfg(feature = "tracing")]
tracing::trace!(target: "porkbun_api::transport", "stored {} cookies from response", cookies.len());
let mut jar = self.cookie_jar.write().await;
for cookie in cookies {
jar.add(cookie.into_owned());
}
}
Ok(response)
}
}
/// The built-in transport: HTTP/2 over TLS to the Porkbun API, with a cookie jar and
/// automatic retries on transient gateway errors.
pub struct DefaultTransport(Retry502<TrackCookies<Http2Only>>);
impl Default for DefaultTransport {
    /// Same as `DefaultTransport::new(false)` (connect to `api.porkbun.com`).
    fn default() -> Self {
        Self::new(false)
    }
}
impl DefaultTransport {
    /// Builds the transport stack. When `force_ipv4` is set, it connects to
    /// `api-ipv4.porkbun.com` instead of `api.porkbun.com`.
    pub fn new(force_ipv4: bool) -> Self {
        let connection = Http2Only::new(force_ipv4);
        let with_cookies = TrackCookies::wrapping(connection);
        Self(Retry502::wrapping(with_cookies))
    }
}
#[allow(clippy::enum_variant_names)]
#[derive(Debug)]
enum DefaultTransportErrorImpl {
ConnectionError(std::io::Error),
RetryError,
HttpError(hyper::Error),
}
impl From<hyper::Error> for DefaultTransportErrorImpl {
fn from(value: hyper::Error) -> Self {
Self::HttpError(value)
}
}
/// Opaque error returned by the built-in transport.
pub struct DefaultTransportError(DefaultTransportErrorImpl);
impl Debug for DefaultTransportError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate with the same formatter so flags such as `{:#?}` reach the inner enum.
        let DefaultTransportError(inner) = self;
        <DefaultTransportErrorImpl as Debug>::fmt(inner, f)
    }
}
impl<T> From<T> for DefaultTransportError
where
    T: Into<DefaultTransportErrorImpl>,
{
    /// Converts anything convertible into the private representation (e.g. `hyper::Error`),
    /// which is what lets `?` work inside the transport code.
    fn from(value: T) -> Self {
        let inner: DefaultTransportErrorImpl = value.into();
        DefaultTransportError(inner)
    }
}
impl Error for DefaultTransportError {
    /// Exposes the underlying I/O or hyper error. `RetryError` has no source.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match &self.0 {
            DefaultTransportErrorImpl::ConnectionError(e) => Some(e),
            DefaultTransportErrorImpl::HttpError(e) => Some(e),
            DefaultTransportErrorImpl::RetryError => None,
        }
    }
}
impl Display for DefaultTransportError {
    /// Writes a short, user-facing description. Details are available via `source()`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self.0 {
            DefaultTransportErrorImpl::ConnectionError(_) => "Failed to connect to endpoint",
            DefaultTransportErrorImpl::HttpError(_) => "HTTP protocol error",
            DefaultTransportErrorImpl::RetryError => {
                "Server took too many tries to reply with a non-502 status code"
            }
        };
        f.write_str(message)
    }
}
impl MakeRequest for DefaultTransport {
    type Body = Incoming;
    type Error = DefaultTransportError;
    /// Forwards to the wrapped retry -> cookie -> HTTP/2 stack.
    async fn request(
        &self,
        request: Request<Full<Bytes>>,
    ) -> Result<Response<Self::Body>, Self::Error> {
        let Self(stack) = self;
        stack.request(request).await
    }
}
}
#[cfg(feature = "default-client")]
pub use default_impl::*;