use std::{collections::HashSet, sync::Arc, time::Duration};
#[cfg(any(feature = "record", feature = "replay"))]
use crate::vendor::rvcr::{VCRMiddleware, VCRMode};
use async_compression::tokio::write::ZstdEncoder;
use async_trait::async_trait;
use http::Extensions;
use reqwest::{header::HeaderValue, header::CONTENT_ENCODING, Body, Request, Response, StatusCode};
use reqwest_middleware::{Middleware, Next};
use reqwest_retry::{
default_on_request_failure,
policies::{ExponentialBackoff, ExponentialBackoffTimed},
RetryTransientMiddleware, Retryable, RetryableStrategy,
};
#[cfg(any(feature = "record", feature = "replay"))]
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
#[cfg(not(any(feature = "record", feature = "replay")))]
use tokio_util::io::ReaderStream;
use tracing::{debug, error};
/// Buffer size, in bytes, passed to `tokio::io::duplex` for the pipe that
/// carries zstd output from the compression task into the request body.
const ZSTD_BUFFER_SIZE: usize = 4 * 1024;
/// Request bodies shorter than this many bytes are sent uncompressed.
const ZSTD_MIN_BODY_SIZE: usize = 512;
/// Middleware that collapses repeated request headers to a single value.
///
/// For every header name that occurs more than once in the outgoing request,
/// only the first value is kept and all later values for that name are
/// dropped. Headers that occur exactly once are left untouched.
pub struct HeaderDeduplicatorMiddleware;
#[async_trait::async_trait]
impl Middleware for HeaderDeduplicatorMiddleware {
    async fn handle(
        &self,
        mut req: reqwest::Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> Result<reqwest::Response, reqwest_middleware::Error> {
        let headers = req.headers_mut();
        // `keys()` yields each distinct name once, so every duplicated header is
        // visited exactly once no matter how many values it carries.
        let duplicated: Vec<_> = headers
            .keys()
            .filter(|name| headers.get_all(*name).iter().nth(1).is_some())
            .cloned()
            .collect();
        for name in duplicated {
            // `get` returns the first value; `insert` replaces *all* existing
            // values for `name`, so no explicit `remove` is needed.
            if let Some(first) = headers.get(&name).cloned() {
                headers.insert(name, first);
            }
        }
        next.run(req, extensions).await
    }
}
/// Middleware that traces every outgoing request and the result produced by
/// the rest of the middleware chain at `debug` level.
///
/// NOTE(review): the `Debug` output of a request presumably includes its
/// headers (e.g. `Authorization`) unless their values are marked sensitive;
/// confirm debug logs cannot leak credentials.
pub struct LoggingMiddleware;
#[async_trait::async_trait]
impl Middleware for LoggingMiddleware {
    async fn handle(
        &self,
        req: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> reqwest_middleware::Result<Response> {
        debug!("Request sent: {:?}", req);
        // Log the whole `Result` so transport errors are traced as well.
        let outcome = next.run(req, extensions).await;
        debug!("Response received: {:?}", outcome);
        outcome
    }
}
pub struct Retry500;
impl RetryableStrategy for Retry500 {
fn handle(
&self,
res: &Result<reqwest::Response, reqwest_middleware::Error>,
) -> Option<Retryable> {
let unrecoverable_codes = [
StatusCode::from_u16(400).unwrap(),
StatusCode::from_u16(401).unwrap(),
StatusCode::from_u16(403).unwrap(),
StatusCode::from_u16(404).unwrap(),
StatusCode::from_u16(405).unwrap(),
StatusCode::from_u16(406).unwrap(),
StatusCode::from_u16(407).unwrap(),
StatusCode::from_u16(408).unwrap(),
StatusCode::from_u16(409).unwrap(),
StatusCode::from_u16(410).unwrap(),
StatusCode::from_u16(411).unwrap(),
StatusCode::from_u16(412).unwrap(),
StatusCode::from_u16(413).unwrap(),
StatusCode::from_u16(414).unwrap(),
StatusCode::from_u16(415).unwrap(),
StatusCode::from_u16(416).unwrap(),
StatusCode::from_u16(417).unwrap(),
StatusCode::from_u16(418).unwrap(),
StatusCode::from_u16(421).unwrap(),
StatusCode::from_u16(501).unwrap(),
StatusCode::from_u16(505).unwrap(),
StatusCode::from_u16(506).unwrap(),
StatusCode::from_u16(510).unwrap(),
];
let transient_codes = [
StatusCode::from_u16(500).unwrap(),
StatusCode::from_u16(502).unwrap(),
StatusCode::from_u16(503).unwrap(),
StatusCode::from_u16(504).unwrap(),
];
match res {
Ok(success) if transient_codes.contains(&success.status()) => {
debug!(
"Retrying request due to temporary API outage: {}",
success.status()
);
Some(Retryable::Transient)
}
Ok(success) if unrecoverable_codes.contains(&success.status()) => {
debug!(
"Request failed with fatal client error: {}",
success.status()
);
Some(Retryable::Fatal)
}
Ok(_success) => None,
Err(error) => {
debug!("Request failed with network error: {}", error);
default_on_request_failure(error)
}
}
}
}
/// Builds the retry middleware used by the HTTP clients.
///
/// Responses are classified by [`Retry500`]. Retries are spaced with
/// exponential backoff bounded between 1 and 8 seconds. `max_duration` is
/// passed to `build_with_total_retry_duration` as the total retry budget and
/// defaults to 60 seconds when `None`.
pub fn retry_client(
    max_duration: Option<Duration>,
) -> RetryTransientMiddleware<ExponentialBackoffTimed, Retry500> {
    let total_budget = match max_duration {
        Some(budget) => budget,
        None => Duration::from_secs(60),
    };
    let (min_backoff, max_backoff) = (Duration::from_secs(1), Duration::from_secs(8));
    let policy = ExponentialBackoff::builder()
        .retry_bounds(min_backoff, max_backoff)
        .build_with_total_retry_duration(total_budget);
    RetryTransientMiddleware::new_with_policy_and_strategy(policy, Retry500)
}
#[cfg(any(feature = "record", feature = "replay"))]
/// Builds the VCR middleware that records or replays HTTP interactions using
/// the cassette at `bundle`.
///
/// * Every request has its `authorization` header replaced by
///   `Bearer REDACTED_TOKEN` and its `user-agent` header replaced by a fixed
///   value (presumably so cassettes hold no real credentials and match across
///   client versions).
/// * Responses whose `content-type` contains `application/octet-stream` get
///   `body.encoding` cleared. NOTE(review): presumably this keeps binary bodies
///   verbatim; confirm against the vendored rvcr implementation.
/// * With the `record` feature the middleware is switched to
///   `VCRMode::Record`; otherwise the mode from `VCRMiddleware::try_from` is kept.
///
/// # Panics
///
/// Panics if the cassette at `bundle` cannot be loaded.
pub fn vcr_middleware(bundle: std::path::PathBuf) -> VCRMiddleware {
    // Render the path before `bundle` is moved, so a load failure says which
    // cassette was involved.
    let bundle_path = bundle.display().to_string();
    let vcr = VCRMiddleware::try_from(bundle)
        .unwrap_or_else(|err| {
            panic!("failed to load VCR cassette {}: {:?}", bundle_path, err)
        })
        .with_modify_request(|req| {
            req.headers.insert(
                "authorization".to_string(),
                vec!["Bearer REDACTED_TOKEN".to_string()],
            );
            req.headers.insert(
                "user-agent".to_string(),
                vec!["OpenAPI-Generator/v0.0.0/rust".to_string()],
            );
        })
        .with_modify_response(|res| {
            let is_binary = res
                .headers
                .get("content-type")
                .and_then(|values| values.first())
                .map(|v| v.contains("application/octet-stream"))
                .unwrap_or(false);
            if is_binary {
                res.body.encoding = None;
            }
        });
    if cfg!(feature = "record") {
        vcr.with_mode(VCRMode::Record)
    } else {
        vcr
    }
}
#[derive(Debug)]
pub struct ZstdRequestCompressionMiddleware;
#[async_trait]
impl Middleware for ZstdRequestCompressionMiddleware {
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> reqwest_middleware::Result<Response> {
if let Some(bytes) = req
.body()
.and_then(|b| b.as_bytes())
.filter(|b| b.len() >= ZSTD_MIN_BODY_SIZE)
{
let (method, url, headers, version) = (
req.method().clone(),
req.url().clone(),
req.headers().clone(),
req.version(),
);
let mut new_req = Request::new(method, url);
*new_req.headers_mut() = headers;
*new_req.version_mut() = version;
let (writer, reader) = tokio::io::duplex(ZSTD_BUFFER_SIZE);
let body_arc = Arc::new(bytes.to_vec());
let body_clone = Arc::clone(&body_arc);
tokio::spawn(async move {
let mut encoder = ZstdEncoder::new(writer);
if let Err(error) = encoder.write_all(&body_clone).await {
error!("Failed to compress body: {}", error);
}
let _ = encoder.shutdown().await;
});
#[cfg(not(any(feature = "record", feature = "replay")))]
{
new_req
.body_mut()
.replace(Body::wrap_stream(ReaderStream::new(reader)));
}
#[cfg(any(feature = "record", feature = "replay"))]
{
let mut buf = Vec::new();
let reader = Arc::new(tokio::sync::Mutex::new(reader));
let mut reader_lock = reader.lock().await;
reader_lock.read_to_end(&mut buf).await.unwrap();
new_req.body_mut().replace(Body::from(buf));
}
new_req
.headers_mut()
.insert(CONTENT_ENCODING, HeaderValue::from_static("zstd"));
return next.run(new_req, extensions).await;
}
next.run(req, extensions).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use std::time::{Duration, Instant};

    use async_compression::tokio::bufread::ZstdDecoder;
    use reqwest::header::{HeaderMap, HeaderValue};
    use tokio::io::AsyncReadExt;
    use tokio::io::BufReader;
    use wiremock::{
        matchers::{header, method},
        Mock, MockServer, ResponseTemplate,
    };

    /// Wraps a default `reqwest::Client` with the single given middleware.
    fn client_with<M: Middleware>(middleware: M) -> reqwest_middleware::ClientWithMiddleware {
        let inner = reqwest::Client::builder()
            .build()
            .expect("Could not build client");
        reqwest_middleware::ClientBuilder::new(inner)
            .with(middleware)
            .build()
    }

    /// A request carrying two `Authorization` values must reach the server
    /// with only the first one.
    #[tokio::test]
    async fn test_header_deduplicator() {
        let mut headers = HeaderMap::new();
        headers.append("Authorization", HeaderValue::from_static("Bearer firstkey"));
        headers.append(
            "Authorization",
            HeaderValue::from_static("Bearer secondkey"),
        );
        headers.append("Content-Type", HeaderValue::from_static("application/json"));

        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(header("Authorization", "Bearer secondkey"))
            .respond_with(ResponseTemplate::new(400))
            .mount(&mock_server)
            .await;
        Mock::given(method("GET"))
            .and(header("Authorization", "Bearer firstkey"))
            .and(header("Content-Type", "application/json"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&mock_server)
            .await;

        let client = client_with(HeaderDeduplicatorMiddleware);
        let mut request = client.get(mock_server.uri()).build().unwrap();
        *request.headers_mut() = headers;
        let response = client.execute(request).await.unwrap();

        // 400 or 404 would mean the duplicate header was not collapsed.
        assert_eq!(response.status(), 200);
    }

    /// A persistent 500 keeps being retried until the 15 s budget is spent.
    /// Slow by design: this test takes roughly 15-25 seconds.
    #[tokio::test]
    async fn test_retry_policy_on_500() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(500))
            .mount(&mock_server)
            .await;

        let client = client_with(retry_client(Some(Duration::from_secs(15))));
        let request = client.get(mock_server.uri()).build().unwrap();

        let start = Instant::now();
        client.execute(request).await.unwrap();
        let elapsed = start.elapsed();

        let num_requests = mock_server.received_requests().await.unwrap().len();
        assert!(num_requests > 3);
        assert!(
            elapsed >= Duration::from_secs(15) && elapsed <= Duration::from_secs(25),
            "unexpected retry duration: {:?}",
            elapsed
        );
    }

    /// A 400 is fatal, so exactly one request is sent.
    #[tokio::test]
    async fn test_retry_policy_on_400() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(400))
            .mount(&mock_server)
            .await;

        let client = client_with(retry_client(None));
        let request = client.get(mock_server.uri()).build().unwrap();
        client.execute(request).await.unwrap();

        let num_requests = mock_server.received_requests().await.unwrap().len();
        assert_eq!(num_requests, 1);
    }

    /// A body at the size threshold is sent zstd-encoded and round-trips.
    #[tokio::test]
    async fn test_zstd_request_compression() {
        let original_body = "A".repeat(ZSTD_MIN_BODY_SIZE);
        let mock_server = MockServer::start().await;
        let _mock = Mock::given(method("POST"))
            .and(header("Content-Encoding", "zstd"))
            .respond_with(ResponseTemplate::new(200))
            .expect(1)
            .mount_as_scoped(&mock_server)
            .await;

        let client = client_with(ZstdRequestCompressionMiddleware);
        let request = client
            .post(mock_server.uri())
            .body(original_body.clone())
            .build()
            .unwrap();
        let response = client.execute(request).await.unwrap();
        assert_eq!(response.status(), 200);

        let received_requests = mock_server.received_requests().await.unwrap();
        assert_eq!(received_requests.len(), 1);

        let compressed_body = received_requests[0].body.clone();
        let mut decoder = ZstdDecoder::new(BufReader::new(Cursor::new(compressed_body)));
        let mut decompressed_body = Vec::new();
        decoder.read_to_end(&mut decompressed_body).await.unwrap();
        assert_eq!(
            decompressed_body,
            original_body.as_bytes(),
            "Decompressed body does not match original input."
        );
    }
}