use std::{
future::Future,
sync::Arc,
task::{Context, Poll},
};
use http::{HeaderMap, Request, Response, StatusCode};
use pin_project_lite::pin_project;
use tower::{Layer, Service};
use crate::{Body, Error};
/// Hook that receives mutable access to a request before it is dispatched.
///
/// Within this module, `Hooks::run_before_request` calls the registered hooks
/// in registration order and stops at the first one that returns `Err`.
pub trait BeforeRequestHook: Send + Sync {
/// Inspects or mutates `request`; returning `Err` rejects the request.
fn on_request(&self, request: &mut Request<Body>) -> Result<(), Error>;
}
/// Hook that observes the status code and headers of a received response.
///
/// The hook only gets read access to the response metadata. Within this module,
/// `Hooks::run_after_response` calls the registered hooks in registration order
/// and stops at the first one that returns `Err`.
pub trait AfterResponseHook: Send + Sync {
/// Inspects `status` and `headers`; returning `Err` turns the response into an error.
fn on_response(&self, status: StatusCode, headers: &HeaderMap) -> Result<(), Error>;
}
/// Hook that is notified about an error.
///
/// The method returns nothing, so an error hook can observe a failure but
/// cannot change the outcome.
pub trait OnErrorHook: Send + Sync {
/// Called with a reference to the error that occurred.
fn on_error(&self, error: &Error);
}
/// Registry of hooks, grouped by the lifecycle point they attach to.
///
/// Hooks are stored as `Arc<dyn …>` trait objects, so cloning the registry
/// clones the pointers rather than the hook objects.
#[derive(Clone, Default)]
pub struct Hooks {
// Hooks run against the request before it is dispatched.
pub(crate) before_request: Vec<Arc<dyn BeforeRequestHook>>,
// Hooks run against the status and headers of a received response.
pub(crate) after_response: Vec<Arc<dyn AfterResponseHook>>,
// Hooks notified when an error occurs.
pub(crate) on_error: Vec<Arc<dyn OnErrorHook>>,
}
/// Builder that accumulates hooks and produces a [`Hooks`] registry.
#[derive(Default)]
pub struct HooksBuilder {
// Registry under construction; handed out unchanged by `build`.
hooks: Hooks,
}
impl Hooks {
#[inline]
pub fn new() -> Self {
Self::default()
}
#[inline]
pub fn builder() -> HooksBuilder {
HooksBuilder::default()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.before_request.is_empty() && self.after_response.is_empty() && self.on_error.is_empty()
}
pub(crate) fn run_before_request(&self, request: &mut Request<Body>) -> Result<(), Error> {
for hook in &self.before_request {
hook.on_request(request)?;
}
Ok(())
}
pub(crate) fn run_after_response(
&self,
status: StatusCode,
headers: &HeaderMap,
) -> Result<(), Error> {
for hook in &self.after_response {
hook.on_response(status, headers)?;
}
Ok(())
}
pub(crate) fn run_on_error(&self, error: &Error) {
for hook in &self.on_error {
hook.on_error(error);
}
}
}
impl HooksBuilder {
    /// Registers `hook` to run before each request; hooks run in the order added.
    pub fn before_request<H>(self, hook: Arc<H>) -> Self
    where
        H: BeforeRequestHook + 'static,
    {
        let Self { mut hooks } = self;
        // Erase the concrete type so differently typed hooks share one list.
        let erased: Arc<dyn BeforeRequestHook> = hook;
        hooks.before_request.push(erased);
        Self { hooks }
    }

    /// Registers `hook` to run after each response; hooks run in the order added.
    pub fn after_response<H>(self, hook: Arc<H>) -> Self
    where
        H: AfterResponseHook + 'static,
    {
        let Self { mut hooks } = self;
        let erased: Arc<dyn AfterResponseHook> = hook;
        hooks.after_response.push(erased);
        Self { hooks }
    }

    /// Registers `hook` to be notified of errors; hooks run in the order added.
    pub fn on_error<H>(self, hook: Arc<H>) -> Self
    where
        H: OnErrorHook + 'static,
    {
        let Self { mut hooks } = self;
        let erased: Arc<dyn OnErrorHook> = hook;
        hooks.on_error.push(erased);
        Self { hooks }
    }

    /// Consumes the builder and returns the accumulated [`Hooks`].
    pub fn build(self) -> Hooks {
        let Self { hooks } = self;
        hooks
    }
}
/// Hook that logs requests and responses through `tracing` when the crate is
/// built with the `tracing` feature; without that feature it does nothing.
#[derive(Clone, Default)]
pub struct LoggingHook {
// When true, each header is additionally emitted as a trace-level event.
log_headers: bool,
}
impl LoggingHook {
    /// Creates a logging hook with per-header logging turned off.
    pub fn new() -> Self {
        Self { log_headers: false }
    }

    /// Returns the hook with per-header logging turned on.
    pub fn with_headers(self) -> Self {
        Self {
            log_headers: true,
            ..self
        }
    }
}
impl BeforeRequestHook for LoggingHook {
    /// Logs the request method and URI at debug level and, when header logging
    /// is enabled, each request header at trace level.
    ///
    /// Never modifies the request and always returns `Ok(())`.
    fn on_request(&self, request: &mut Request<Body>) -> Result<(), Error> {
        #[cfg(not(feature = "tracing"))]
        {
            // Nothing to log without `tracing`; touch the inputs so this
            // configuration builds without unused-variable warnings.
            let _ = (request, self.log_headers);
        }
        #[cfg(feature = "tracing")]
        {
            tracing::debug!(
                method = %request.method(),
                uri = %request.uri(),
                "Sending request"
            );
            if self.log_headers {
                request.headers().iter().for_each(|(name, value)| {
                    tracing::trace!(header = %name, value = ?value);
                });
            }
        }
        Ok(())
    }
}
impl AfterResponseHook for LoggingHook {
    /// Logs the response status at debug level and, when header logging is
    /// enabled, each response header at trace level.
    ///
    /// Always returns `Ok(())`.
    fn on_response(&self, status: StatusCode, headers: &HeaderMap) -> Result<(), Error> {
        #[cfg(not(feature = "tracing"))]
        {
            // Nothing to log without `tracing`; touch the inputs so this
            // configuration builds without unused-variable warnings.
            let _ = (status, headers, self.log_headers);
        }
        #[cfg(feature = "tracing")]
        {
            tracing::debug!(
                status = %status,
                "Received response"
            );
            if self.log_headers {
                headers.iter().for_each(|(name, value)| {
                    tracing::trace!(header = %name, value = ?value);
                });
            }
        }
        Ok(())
    }
}
/// Before-request hook that writes a fixed set of headers onto every request.
pub struct HeaderInjectionHook {
// Headers copied onto each request by `on_request`.
headers: HeaderMap,
}
impl HeaderInjectionHook {
pub fn new(headers: HeaderMap) -> Self {
Self { headers }
}
pub fn single(name: http::header::HeaderName, value: http::HeaderValue) -> Self {
let mut headers = HeaderMap::new();
headers.insert(name, value);
Self { headers }
}
}
impl BeforeRequestHook for HeaderInjectionHook {
    /// Copies every configured header onto `request`, replacing any values the
    /// request already carries under the same name.
    ///
    /// A configured header with several values is injected with all of them,
    /// in their stored order. Always returns `Ok(())`.
    fn on_request(&self, request: &mut Request<Body>) -> Result<(), Error> {
        let target = request.headers_mut();
        for name in self.headers.keys() {
            let configured = self.headers.get_all(name);
            let mut values = configured.iter();
            // `insert` discards whatever the request had under this name, so
            // it is used for the first value only; using it for every value
            // would leave just the last one behind.
            if let Some(first) = values.next() {
                target.insert(name.clone(), first.clone());
            }
            for extra in values {
                target.append(name.clone(), extra.clone());
            }
        }
        Ok(())
    }
}
/// Before-request hook that stamps each request with a generated identifier header.
pub struct RequestIdHook {
// Name of the header that receives the generated identifier.
header_name: http::header::HeaderName,
}
impl Default for RequestIdHook {
    /// Uses `x-request-id` as the identifier header name.
    fn default() -> Self {
        let header_name = http::header::HeaderName::from_static("x-request-id");
        Self { header_name }
    }
}
impl RequestIdHook {
    /// Creates a hook that writes the identifier to the default header name.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the hook configured to write the identifier to `name` instead.
    pub fn with_header_name(self, name: http::header::HeaderName) -> Self {
        Self {
            header_name: name,
            ..self
        }
    }
}
impl BeforeRequestHook for RequestIdHook {
    /// Sets the configured header to a freshly generated request identifier,
    /// replacing any value already present under that name.
    ///
    /// The identifier has the form `<hex nanoseconds since epoch>-<hex sequence>`.
    /// The sequence number is a process-wide counter, so two requests stamped
    /// within the same clock tick (or while the clock reads before the Unix
    /// epoch, where the timestamp falls back to `0`) still get distinct ids.
    /// Always returns `Ok(())`.
    fn on_request(&self, request: &mut Request<Body>) -> Result<(), Error> {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::time::{SystemTime, UNIX_EPOCH};

        // Shared by every `RequestIdHook`; wraps around on overflow.
        static SEQUENCE: AtomicUsize = AtomicUsize::new(0);

        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|elapsed| elapsed.as_nanos())
            .unwrap_or(0);
        // Only uniqueness matters here, not ordering with other memory
        // operations, so a relaxed increment is sufficient.
        let sequence = SEQUENCE.fetch_add(1, Ordering::Relaxed);
        let id = format!("{:x}-{:x}", timestamp, sequence);

        if let Ok(value) = http::HeaderValue::from_str(&id) {
            request
                .headers_mut()
                .insert(self.header_name.clone(), value);
        }
        Ok(())
    }
}
/// `tower` layer that wraps a service in a [`HooksService`].
#[derive(Clone)]
pub struct HooksLayer {
// Registry shared by pointer with every service this layer produces.
hooks: Arc<Hooks>,
}
impl HooksLayer {
pub fn new(hooks: Hooks) -> Self {
Self {
hooks: Arc::new(hooks),
}
}
}
impl<S> Layer<S> for HooksLayer {
type Service = HooksService<S>;
fn layer(&self, inner: S) -> Self::Service {
HooksService {
inner,
hooks: self.hooks.clone(),
}
}
}
/// Service wrapper that runs hooks around calls to the inner service `S`.
#[derive(Clone)]
pub struct HooksService<S> {
// The wrapped service that actually handles the request.
inner: S,
// Hook registry, shared by pointer with the layer that created this service.
hooks: Arc<Hooks>,
}
pin_project! {
/// Future returned by [`HooksService`]; it resolves either to a stored
/// error or to the outcome of the inner service's future.
pub struct HooksFuture<Fut> {
// Hook registry consulted when the future resolves.
hooks: Arc<Hooks>,
// Either a stored error or the in-flight inner future; pinned because
// the inner future must not move once polled.
#[pin]
state: HooksFutureState<Fut>,
}
}
// Internal state of `HooksFuture`.
pin_project! {
#[project = HooksFutureStateProj]
enum HooksFutureState<Fut> {
// Holds an error to be returned on the first poll; the `Option` is
// taken when the error is handed out.
Error {
error: Option<crate::error::BoxError>,
},
// Holds the inner service's future, polled in place.
Response {
#[pin]
future: Fut,
},
}
}
impl<S, ResBody> Service<Request<Body>> for HooksService<S>
where
S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
S::Error: Into<crate::error::BoxError> + Send,
S::Future: Send + 'static,
ResBody: Send + 'static,
{
type Response = Response<ResBody>;
type Error = crate::error::BoxError;
type Future = HooksFuture<S::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
}
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
let hooks = self.hooks.clone();
let mut inner = self.inner.clone();
if let Err(e) = hooks.run_before_request(&mut req) {
hooks.run_on_error(&e);
return HooksFuture {
hooks,
state: HooksFutureState::Error {
error: Some(e.into()),
},
};
}
HooksFuture {
hooks,
state: HooksFutureState::Response {
future: inner.call(req),
},
}
}
}
impl<Fut, ResBody, Err> Future for HooksFuture<Fut>
where
Fut: Future<Output = Result<Response<ResBody>, Err>>,
Err: Into<crate::error::BoxError>,
{
type Output = Result<Response<ResBody>, crate::error::BoxError>;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
match this.state.as_mut().project() {
HooksFutureStateProj::Error { error } => {
let error = error
.take()
.expect("HooksFuture::Error polled after completion");
Poll::Ready(Err(error))
}
HooksFutureStateProj::Response { future } => match future.poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(response)) => {
if let Err(e) = this
.hooks
.run_after_response(response.status(), response.headers())
{
this.hooks.run_on_error(&e);
Poll::Ready(Err(e.into()))
} else {
Poll::Ready(Ok(response))
}
}
Poll::Ready(Err(error)) => {
let boxed_error: crate::error::BoxError = error.into();
let error = Error::request(boxed_error);
this.hooks.run_on_error(&error);
Poll::Ready(Err(error.into()))
}
},
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with an empty body for exercising hooks.
    fn empty_request() -> Request<Body> {
        Request::builder()
            .uri("https://example.com")
            .body(Body::empty())
            .unwrap()
    }

    #[test]
    fn test_hooks_builder() {
        // A fresh builder and a fresh registry both start empty.
        assert!(Hooks::builder().build().is_empty());
        assert!(Hooks::new().is_empty());

        // Registering a hook lands it in the matching list only.
        let hooks = Hooks::builder()
            .before_request(Arc::new(LoggingHook::new()))
            .build();
        assert!(!hooks.is_empty());
        assert_eq!(hooks.before_request.len(), 1);
        assert!(hooks.after_response.is_empty());
        assert!(hooks.on_error.is_empty());
    }

    #[test]
    fn test_logging_hook() {
        assert!(!LoggingHook::new().log_headers);
        let hook = LoggingHook::new().with_headers();
        assert!(hook.log_headers);
    }

    #[test]
    fn test_header_injection_hook() {
        use http::header;
        let hook = HeaderInjectionHook::single(
            header::AUTHORIZATION,
            header::HeaderValue::from_static("Bearer token"),
        );
        let mut req = empty_request();
        hook.on_request(&mut req).unwrap();
        // Check the injected value, not merely that the key exists.
        assert_eq!(
            req.headers().get(header::AUTHORIZATION),
            Some(&header::HeaderValue::from_static("Bearer token"))
        );
    }

    #[test]
    fn test_request_id_hook() {
        let hook = RequestIdHook::new();
        let mut req = empty_request();
        hook.on_request(&mut req).unwrap();
        let id = req
            .headers()
            .get("x-request-id")
            .expect("request id header should be set");
        assert!(!id.as_bytes().is_empty());
    }

    #[test]
    fn test_request_id_hook_custom_header() {
        let hook = RequestIdHook::new()
            .with_header_name(http::header::HeaderName::from_static("x-trace-id"));
        let mut req = empty_request();
        hook.on_request(&mut req).unwrap();
        assert!(req.headers().contains_key("x-trace-id"));
        assert!(!req.headers().contains_key("x-request-id"));
    }
}