use crate::error::Result;
use async_trait::async_trait;
use http::header::{HeaderName, HeaderValue};
use http::HeaderMap;
use std::collections::HashMap;
use std::sync::Arc;
/// Per-request state handed to every middleware hook.
///
/// The `metadata` map lives behind an `Arc`, so cloning a context is cheap and
/// every clone reads and writes the *same* map.
#[derive(Debug, Clone)]
pub struct HttpMiddlewareContext {
    /// Optional request identifier; [`HttpMiddlewareContext::new`] leaves it `None`.
    pub request_id: Option<String>,
    /// Target URL of the request.
    pub url: String,
    /// HTTP method, stored as a plain string.
    pub method: String,
    /// Attempt number; `new` initialises it to 0 (presumably bumped by the
    /// caller on retries — confirm at call sites).
    pub attempt: u32,
    /// Free-form string annotations guarded by a read/write lock and shared
    /// between all clones of this context.
    pub metadata: Arc<parking_lot::RwLock<HashMap<String, String>>>,
}

impl HttpMiddlewareContext {
    /// Creates a context for `method` on `url` with no request id, attempt 0
    /// and an empty metadata map.
    pub fn new(url: String, method: String) -> Self {
        let metadata = Arc::new(parking_lot::RwLock::new(HashMap::new()));
        Self {
            url,
            method,
            request_id: None,
            attempt: 0,
            metadata,
        }
    }

    /// Stores `value` under `key`, overwriting any previous value.
    ///
    /// The write lock is held only for the duration of this call.
    pub fn set_metadata(&self, key: String, value: String) {
        let mut entries = self.metadata.write();
        entries.insert(key, value);
    }

    /// Returns an owned copy of the value stored under `key`, if present.
    pub fn get_metadata(&self, key: &str) -> Option<String> {
        let entries = self.metadata.read();
        entries.get(key).map(String::clone)
    }
}
#[derive(Debug, Clone)]
pub struct HttpRequest {
pub method: String,
pub url: String,
pub headers: HeaderMap,
pub body: Vec<u8>,
}
impl HttpRequest {
pub fn new(method: String, url: String, body: Vec<u8>) -> Self {
Self {
method,
url,
headers: HeaderMap::new(),
body,
}
}
pub fn add_header(&mut self, name: &str, value: &str) {
let header_name = HeaderName::from_bytes(name.as_bytes()).expect("Invalid header name");
let header_value = HeaderValue::from_str(value).expect("Invalid header value");
self.headers.insert(header_name, header_value);
}
pub fn get_header(&self, name: &str) -> Option<&str> {
let header_name = HeaderName::from_bytes(name.as_bytes()).ok()?;
self.headers.get(header_name)?.to_str().ok()
}
pub fn has_header(&self, name: &str) -> bool {
HeaderName::from_bytes(name.as_bytes())
.ok()
.and_then(|n| self.headers.get(n))
.is_some()
}
pub fn remove_header(&mut self, name: &str) -> Option<String> {
let header_name = HeaderName::from_bytes(name.as_bytes()).ok()?;
self.headers
.remove(header_name)
.and_then(|v| v.to_str().ok().map(|s| s.to_string()))
}
}
#[derive(Debug, Clone)]
pub struct HttpResponse {
pub status: u16,
pub headers: HeaderMap,
pub body: Vec<u8>,
}
impl HttpResponse {
pub fn new(status: u16, body: Vec<u8>) -> Self {
Self {
status,
headers: HeaderMap::new(),
body,
}
}
pub fn with_headers(status: u16, headers: HeaderMap, body: Vec<u8>) -> Self {
Self {
status,
headers,
body,
}
}
pub fn add_header(&mut self, name: &str, value: &str) {
let header_name = HeaderName::from_bytes(name.as_bytes()).expect("Invalid header name");
let header_value = HeaderValue::from_str(value).expect("Invalid header value");
self.headers.insert(header_name, header_value);
}
pub fn get_header(&self, name: &str) -> Option<&str> {
let header_name = HeaderName::from_bytes(name.as_bytes()).ok()?;
self.headers.get(header_name)?.to_str().ok()
}
pub fn has_header(&self, name: &str) -> bool {
HeaderName::from_bytes(name.as_bytes())
.ok()
.and_then(|n| self.headers.get(n))
.is_some()
}
pub fn is_success(&self) -> bool {
(200..300).contains(&self.status)
}
pub fn is_client_error(&self) -> bool {
(400..500).contains(&self.status)
}
pub fn is_server_error(&self) -> bool {
(500..600).contains(&self.status)
}
}
/// Hook points that [`HttpMiddlewareChain`] invokes around an HTTP exchange.
///
/// Every method has a default implementation, so implementors override only
/// the hooks they care about.
#[async_trait]
pub trait HttpMiddleware: Send + Sync {
    /// Ordering key used by the chain, which keeps middlewares sorted by
    /// ascending priority. Defaults to `50`.
    fn priority(&self) -> i32 {
        50
    }

    /// Gate checked by the chain before calling `on_request` / `on_response`
    /// for this middleware. Defaults to `true`.
    async fn should_execute(&self, _context: &HttpMiddlewareContext) -> bool {
        true
    }

    /// Request hook; may mutate `request`. Returning `Err` aborts the chain.
    /// The default does nothing.
    async fn on_request(
        &self,
        request: &mut HttpRequest,
        context: &HttpMiddlewareContext,
    ) -> Result<()> {
        // Bind the arguments so the default body produces no unused warnings.
        let _ = request;
        let _ = context;
        Ok(())
    }

    /// Response hook; may mutate `response`. Returning `Err` aborts the chain.
    /// The default does nothing.
    async fn on_response(
        &self,
        response: &mut HttpResponse,
        context: &HttpMiddlewareContext,
    ) -> Result<()> {
        let _ = response;
        let _ = context;
        Ok(())
    }

    /// Error notification hook. The default does nothing.
    async fn on_error(
        &self,
        error: &crate::error::Error,
        context: &HttpMiddlewareContext,
    ) -> Result<()> {
        let _ = error;
        let _ = context;
        Ok(())
    }
}
/// An ordered list of [`HttpMiddleware`]s applied around an HTTP exchange.
///
/// Middlewares are kept sorted by ascending [`HttpMiddleware::priority`].
/// Request hooks run from lowest to highest priority, and response hooks run
/// in the reverse order. Middlewares with equal priority keep their
/// registration order.
pub struct HttpMiddlewareChain {
    // Invariant (maintained by `add`): sorted by ascending `priority()`,
    // with ties in insertion order.
    middlewares: Vec<Arc<dyn HttpMiddleware>>,
}

impl HttpMiddlewareChain {
    /// Creates an empty chain.
    pub fn new() -> Self {
        Self {
            middlewares: Vec::new(),
        }
    }

    /// Registers `middleware` at the position given by its priority.
    ///
    /// The new middleware is placed after any already-registered middleware
    /// with the same priority. This assumes `priority()` returns a constant
    /// value for each middleware.
    pub fn add(&mut self, middleware: Arc<dyn HttpMiddleware>) {
        let priority = middleware.priority();
        // The vector is already sorted, so a binary search finds the slot.
        // This avoids re-sorting (and re-querying every priority) on each
        // insertion.
        let index = self
            .middlewares
            .partition_point(|existing| existing.priority() <= priority);
        self.middlewares.insert(index, middleware);
    }

    /// Runs `on_request` of every middleware whose `should_execute` returns
    /// `true`, in ascending priority order.
    ///
    /// # Errors
    ///
    /// Stops at the first hook that fails. Every middleware's `on_error` is
    /// then notified and that error is returned unchanged; remaining
    /// middlewares are skipped.
    pub async fn process_request(
        &self,
        request: &mut HttpRequest,
        context: &HttpMiddlewareContext,
    ) -> Result<()> {
        for middleware in &self.middlewares {
            if middleware.should_execute(context).await {
                if let Err(e) = middleware.on_request(request, context).await {
                    self.handle_error(&e, context).await;
                    return Err(e);
                }
            }
        }
        Ok(())
    }

    /// Runs `on_response` of every middleware whose `should_execute` returns
    /// `true`, in descending priority order (the reverse of the request path).
    ///
    /// # Errors
    ///
    /// Stops at the first hook that fails. Every middleware's `on_error` is
    /// then notified and that error is returned unchanged.
    pub async fn process_response(
        &self,
        response: &mut HttpResponse,
        context: &HttpMiddlewareContext,
    ) -> Result<()> {
        for middleware in self.middlewares.iter().rev() {
            if middleware.should_execute(context).await {
                if let Err(e) = middleware.on_response(response, context).await {
                    self.handle_error(&e, context).await;
                    return Err(e);
                }
            }
        }
        Ok(())
    }

    /// Notifies every registered middleware's `on_error` hook about `error`,
    /// in ascending priority order.
    ///
    /// Failures from the hooks themselves are logged via `tracing::error!` and
    /// otherwise ignored, so the caller still sees the original error.
    // NOTE(review): unlike the request/response paths, this does not consult
    // `should_execute`. Confirm that this is intentional.
    async fn handle_error(&self, error: &crate::error::Error, context: &HttpMiddlewareContext) {
        for middleware in &self.middlewares {
            if let Err(e) = middleware.on_error(error, context).await {
                tracing::error!(
                    "Error in middleware on_error hook: {} (original error: {})",
                    e,
                    error
                );
            }
        }
    }

    /// Public entry point for reporting an error raised outside the
    /// middleware hooks (e.g. by the transport) to every middleware's
    /// `on_error` hook.
    pub async fn handle_transport_error(
        &self,
        error: &crate::error::Error,
        context: &HttpMiddlewareContext,
    ) {
        self.handle_error(error, context).await;
    }
}
impl Default for HttpMiddlewareChain {
fn default() -> Self {
Self::new()
}
}
/// Prints only how many middlewares are registered. The `HttpMiddleware`
/// trait has no `Debug` supertrait, so the entries themselves cannot be
/// formatted.
impl std::fmt::Debug for HttpMiddlewareChain {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let registered = self.middlewares.len();
        let mut builder = f.debug_struct("HttpMiddlewareChain");
        builder.field("count", &registered);
        builder.finish()
    }
}