1use http_body_util::Full;
3use hyper::body::Body;
4use hyper::{body::Bytes, Request, Response};
5
6use std::future::Future;
7
8pub trait MakeRequest: Sized {
10 type Body: Body;
14 type Error: std::error::Error + Send + Sync + 'static;
16 fn request(
18 &self,
19 request: Request<Full<Bytes>>,
20 ) -> impl Future<Output = std::result::Result<Response<Self::Body>, Self::Error>>;
21}
22
23#[cfg(feature = "default-client")]
24mod default_impl {
25 use super::MakeRequest;
26 use cookie::time::OffsetDateTime;
27 use cookie::{Cookie, CookieJar};
28 use http_body_util::Full;
29 use hyper::client::conn::http2::Builder as Http2Builder;
30 use hyper::{
31 body::{Bytes, Incoming},
32 client::conn::http2::SendRequest,
33 Request, Response, StatusCode,
34 };
35 use hyper_util::rt::{TokioExecutor, TokioIo};
36 use rustls::ClientConfig;
37 use std::{
38 error::Error,
39 fmt::{Debug, Display},
40 sync::Arc,
41 time::Duration,
42 };
43 use tokio::{
44 net::TcpStream,
45 sync::{Mutex, RwLock},
46 };
47 use tokio_rustls::TlsConnector;
48
49 struct Http2Only {
50 force_ipv4: bool,
51 config: Arc<ClientConfig>,
52 send: tokio::sync::Mutex<Option<SendRequest<Full<Bytes>>>>,
53 }
54
55 impl Http2Only {
56 async fn make_connection(&self) -> Result<SendRequest<Full<Bytes>>, DefaultTransportError> {
57 let arc_config = self.config.clone();
58 let server_name = if self.force_ipv4 {
59 "api-ipv4.porkbun.com"
60 } else {
61 "api.porkbun.com"
62 }
63 .try_into()
64 .unwrap();
65 let tokio_tls_connecto = TlsConnector::from(arc_config);
66 let tcp = TcpStream::connect(if self.force_ipv4 {
67 "api-ipv4.porkbun.com:443"
68 } else {
69 "api.porkbun.com:443"
70 })
71 .await
72 .map_err(DefaultTransportErrorImpl::ConnectionError)?;
73 let connection = tokio_tls_connecto
74 .connect(server_name, tcp)
75 .await
76 .map_err(DefaultTransportErrorImpl::ConnectionError)?;
77 let hyper_io = TokioIo::new(connection);
78
79 let (send, conn) = Http2Builder::new(TokioExecutor::new())
80 .handshake(hyper_io)
81 .await?;
82 tokio::spawn(conn);
83 Ok(send)
84 }
85 pub fn new(force_ipv4: bool) -> Self {
86 use rustls_platform_verifier::BuilderVerifierExt;
87
88 let mut config = rustls::ClientConfig::builder()
89 .with_platform_verifier()
90 .expect("Failed to create platform verifier")
91 .with_no_client_auth();
92 config.alpn_protocols = vec![b"h2".into()];
93 let config = Arc::new(config);
94
95 Self {
96 force_ipv4,
97 config,
98 send: Mutex::new(None),
99 }
100 }
101 }
102
103 impl Default for Http2Only {
104 fn default() -> Self {
105 Self::new(false)
106 }
107 }
108
109 impl MakeRequest for Http2Only {
110 type Body = Incoming;
111 type Error = DefaultTransportError;
112 async fn request(
113 &self,
114 request: Request<Full<Bytes>>,
115 ) -> Result<Response<Self::Body>, Self::Error> {
116 let mut lock = self.send.lock().await;
117 if lock.is_none() || lock.as_ref().is_some_and(|l| l.is_closed()) {
118 let conn = self.make_connection().await?;
119 *lock = Some(conn)
120 }
121 let sender = lock.as_mut().unwrap();
122 sender.ready().await?;
123 sender
124 .send_request(request)
125 .await
126 .map_err(DefaultTransportError::from)
127 }
128 }
129
/// Middleware that re-sends a request while the server keeps answering with a
/// transient failure status; the exact policy lives in its `MakeRequest` impl.
//
// No trait bound on the struct itself: bounds belong on the impls that need them.
#[derive(Clone)]
struct Retry502<T> {
    /// Transport the (possibly repeated) requests are forwarded to.
    inner: T,
}

impl<T> Retry502<T> {
    /// Wraps `inner` so that its requests are retried on transient failure statuses.
    fn wrapping(inner: T) -> Self {
        Self { inner }
    }
}
140
141 impl<E, T: MakeRequest<Error = E>> MakeRequest for Retry502<T>
142 where
143 DefaultTransportError: From<E>,
144 {
145 type Body = T::Body;
146 type Error = DefaultTransportError;
147 async fn request(
148 &self,
149 request: Request<Full<Bytes>>,
150 ) -> Result<Response<Self::Body>, Self::Error> {
151 let sleep_time = Duration::from_millis(250);
152 let max_sleep = 10;
154 let mut slept = 0;
155
156 let resp = loop {
157 let resp = self.inner.request(request.clone()).await?;
158 if resp.status() != StatusCode::SERVICE_UNAVAILABLE {
159 break resp;
160 } else if slept >= max_sleep {
161 return Err(DefaultTransportError(DefaultTransportErrorImpl::RetryError));
162 } else {
163 slept += 1;
164 tokio::time::sleep(sleep_time).await
165 }
166 };
167 Ok(resp)
168 }
169 }
170
171 pub struct TrackCookies<T> {
177 inner: T,
178 cookie_jar: RwLock<CookieJar>,
179 }
180
181 impl<T> TrackCookies<T> {
182 pub fn wrapping(inner: T) -> Self {
184 Self {
185 inner,
186 cookie_jar: RwLock::new(CookieJar::new()),
187 }
188 }
189
190 fn is_cookie_valid_for_request(cookie: &Cookie, request: &Request<Full<Bytes>>) -> bool {
192 if let Some(domain) = cookie.domain() {
194 if !request.uri().host().unwrap_or("").ends_with(domain) {
195 return false;
196 }
197 }
198 if let Some(path) = cookie.path() {
200 if !request.uri().path().starts_with(path) {
201 return false;
202 }
203 }
204 if let Some(expires) = cookie.expires_datetime() {
206 if expires <= OffsetDateTime::now_utc() {
207 return false;
208 }
209 }
210 true
211 }
212 }
213 impl<T: MakeRequest> MakeRequest for TrackCookies<T> {
214 type Body = T::Body;
215 type Error = T::Error;
216 async fn request(
218 &self,
219 mut request: Request<Full<Bytes>>,
220 ) -> Result<Response<T::Body>, T::Error> {
221 let cookie_header = {
223 let jar = self.cookie_jar.read().await;
224 jar.iter()
225 .filter(|cookie| Self::is_cookie_valid_for_request(cookie, &request))
226 .map(|c| {
227 let (name, value) = c.name_value_trimmed();
228 format!("{name}={value}")
229 })
230 .collect::<Vec<_>>()
231 .join("; ")
232 };
233
234 if !cookie_header.is_empty() {
235 request
236 .headers_mut()
237 .insert(hyper::header::COOKIE, cookie_header.parse().unwrap());
238 }
239
240 let response = self.inner.request(request).await?;
241
242 let cookies = response
244 .headers()
245 .get_all(hyper::header::SET_COOKIE)
246 .iter()
247 .filter_map(|h| h.to_str().ok())
248 .filter_map(|s| Cookie::parse(s).ok())
249 .collect::<Vec<_>>();
250
251 if !cookies.is_empty() {
253 let mut jar = self.cookie_jar.write().await;
254 for cookie in cookies {
255 jar.add(cookie.into_owned());
256 }
257 }
258
259 Ok(response)
260 }
261 }
262
263 pub struct DefaultTransport(Retry502<TrackCookies<Http2Only>>);
270
271 impl Default for DefaultTransport {
272 fn default() -> Self {
273 Self(Retry502::wrapping(TrackCookies::wrapping(
274 Http2Only::default(),
275 )))
276 }
277 }
278
279 impl DefaultTransport {
280 pub fn new(force_ipv4: bool) -> Self {
283 Self(Retry502::wrapping(TrackCookies::wrapping(Http2Only::new(
284 force_ipv4,
285 ))))
286 }
287 }
288
289 #[allow(clippy::enum_variant_names)]
290 #[derive(Debug)]
291 enum DefaultTransportErrorImpl {
292 ConnectionError(std::io::Error),
293 RetryError,
294 HttpError(hyper::Error),
295 }
296
297 impl From<hyper::Error> for DefaultTransportErrorImpl {
298 fn from(value: hyper::Error) -> Self {
299 Self::HttpError(value)
300 }
301 }
302
303 pub struct DefaultTransportError(DefaultTransportErrorImpl);
305
306 impl Debug for DefaultTransportError {
307 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308 Debug::fmt(&self.0, f)
309 }
310 }
311
312 impl<T> From<T> for DefaultTransportError
313 where
314 T: Into<DefaultTransportErrorImpl>,
315 {
316 fn from(value: T) -> Self {
317 Self(value.into())
318 }
319 }
320
321 impl Error for DefaultTransportError {
322 fn source(&self) -> Option<&(dyn Error + 'static)> {
323 match &self.0 {
324 DefaultTransportErrorImpl::ConnectionError(e) => Some(e),
325 DefaultTransportErrorImpl::HttpError(e) => Some(e),
326 DefaultTransportErrorImpl::RetryError => None,
327 }
328 }
329 }
330
331 impl Display for DefaultTransportError {
332 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333 f.write_str(match self.0 {
334 DefaultTransportErrorImpl::ConnectionError(_) => "Failed to connect to endpoint",
335 DefaultTransportErrorImpl::HttpError(_) => "HTTP protocol error",
336 DefaultTransportErrorImpl::RetryError => {
337 "Server took to many tries to reply with a non-502 statuscode"
338 }
339 })
340 }
341 }
342
343 impl MakeRequest for DefaultTransport {
344 type Body = Incoming;
345 type Error = DefaultTransportError;
346 async fn request(
347 &self,
348 request: Request<Full<Bytes>>,
349 ) -> Result<Response<Self::Body>, Self::Error> {
350 self.0.request(request).await
351 }
352 }
353}
354
355#[cfg(feature = "default-client")]
356pub use default_impl::*;